Update `.pre-commit-config.yaml` (#1026)

Glenn Jocher committed 2 years ago via GitHub
parent 9047d737f4
commit edd3ff1669

@@ -1,8 +1,5 @@
-# Define hooks for code formations
-# Will be applied on any updated commit files if a user has installed and linked commit hook
-default_language_version:
-  python: python3.8
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md

 exclude: 'docs/'
 # Define bot property if installed via https://github.com/marketplace/pre-commit-ci
@@ -16,13 +13,13 @@ repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
    hooks:
-      # - id: end-of-file-fixer
+      - id: end-of-file-fixer
       - id: trailing-whitespace
       - id: check-case-conflict
       - id: check-yaml
-      - id: check-toml
-      - id: pretty-format-json
       - id: check-docstring-first
+      - id: double-quote-string-fixer
+      - id: detect-private-key

   - repo: https://github.com/asottile/pyupgrade
     rev: v3.3.1
@@ -64,7 +61,7 @@ repos:
     hooks:
       - id: codespell
         args:
-          - --ignore-words-list=crate,nd,strack
+          - --ignore-words-list=crate,nd,strack,dota
   #- repo: https://github.com/asottile/yesqa
   #  rev: v1.4.0
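Most of the churn in this commit is mechanical fallout from the newly enabled `double-quote-string-fixer` hook. A minimal sketch of its effect (illustrative strings, not taken from the repo):

```python
# Before the hook runs:
name = "yolov8n"      # plain double-quoted string: rewritten to single quotes
msg = "it's a model"  # contains a single quote: left alone to avoid escaping

# After the hook runs:
name = 'yolov8n'
msg = "it's a model"
```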

@@ -31,8 +31,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 # Install pip packages
 COPY requirements.txt .
 RUN python3 -m pip install --upgrade pip wheel
-RUN pip install --no-cache ultralytics[export] albumentations comet gsutil notebook \
-    # tensorflow tensorflowjs \
+RUN pip install --no-cache ultralytics[export] albumentations comet gsutil notebook

 # Set environment variables
 ENV OMP_NUM_THREADS=1

@@ -27,8 +27,6 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 COPY requirements.txt .
 RUN python3 -m pip install --upgrade pip wheel
 RUN pip install --no-cache ultralytics albumentations gsutil notebook
-    # coremltools onnx onnxruntime \
-    # tensorflow-aarch64 tensorflowjs \

 # Cleanup
 ENV DEBIAN_FRONTEND teletype

@@ -27,8 +27,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 COPY requirements.txt .
 RUN python3 -m pip install --upgrade pip wheel
 RUN pip install --no-cache ultralytics[export] albumentations gsutil notebook \
-    # tensorflow-cpu tensorflowjs \
     --extra-index-url https://download.pytorch.org/whl/cpu

 # Cleanup
 ENV DEBIAN_FRONTEND teletype

@@ -92,7 +92,7 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc.
 ## Overriding default arguments
-Default arguments can be overriden by simply passing them as arguments in the CLI in `arg=value` pairs.
+Default arguments can be overridden by simply passing them as arguments in the CLI in `arg=value` pairs.
 !!! tip ""
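The same overrides are available from the Python API as keyword arguments; a minimal sketch (the values are illustrative):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
# keyword arguments override the defaults, mirroring arg=value pairs on the CLI
model.predict(source='https://ultralytics.com/images/bus.jpg', imgsz=320, conf=0.5)
```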

@@ -96,7 +96,7 @@ Class reference documentation for `Results` module and its components can be fou
 ## Visualizing results
-You can use `visualize()` function of `Result` object to get a visualization. It plots all componenets(boxes, masks, classification logits, etc) found in the results object
+You can use `visualize()` function of `Result` object to get a visualization. It plots all components(boxes, masks, classification logits, etc) found in the results object
 ```python
 res = model(img)
 res_plotted = res[0].visualize()

@@ -2,7 +2,7 @@ The simplest way of simply using YOLOv8 directly in a Python environment.
 !!! example "Train"
-    === "From pretrained(recommanded)"
+    === "From pretrained(recommended)"
         ```python
         from ultralytics import YOLO
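For context, a minimal end-to-end version of the workflow this docs page builds toward, with the dataset and settings borrowed from the repo's own tests:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # load a pretrained model (recommended)
model.train(data='coco8.yaml', epochs=1, imgsz=32)  # short training run on the tiny coco8 dataset
model.val()  # evaluate on the validation split
```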

@@ -16,7 +16,7 @@ PKG_REQUIREMENTS = ['sentry_sdk']  # pip-only requirements
 def get_version():
     file = PARENT / 'ultralytics/__init__.py'
-    return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding="utf-8"), re.M)[1]
+    return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding='utf-8'), re.M)[1]

 setup(
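The regex accepts either quote style, so the version parse is unaffected by the quote sweep; a quick self-contained check:

```python
import re

for text in ('__version__ = "8.0.40"\n', "__version__ = '8.0.40'\n"):
    print(re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', text, re.M)[1])  # prints 8.0.40 both times
```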

@@ -49,9 +49,9 @@ def test_val_classify():
 # Predict checks ------------------------------------------------------------------------------------------------------
 def test_predict_detect():
     run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
-    run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
-    run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32")
-    run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32")
+    run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
+    run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
+    run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')

 def test_predict_segment():

@@ -11,12 +11,12 @@ CFG_SEG = 'yolov8n-seg.yaml'
 CFG_CLS = 'squeezenet1_0'
 CFG = get_cfg(DEFAULT_CFG)
 MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
-SOURCE = ROOT / "assets"
+SOURCE = ROOT / 'assets'

 def test_detect():
-    overrides = {"data": "coco8.yaml", "model": CFG_DET, "imgsz": 32, "epochs": 1, "save": False}
-    CFG.data = "coco8.yaml"
+    overrides = {'data': 'coco8.yaml', 'model': CFG_DET, 'imgsz': 32, 'epochs': 1, 'save': False}
+    CFG.data = 'coco8.yaml'

     # Trainer
     trainer = detect.DetectionTrainer(overrides=overrides)
@@ -27,24 +27,24 @@ def test_detect():
     val(model=trainer.best)  # validate best.pt

     # Predictor
-    pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
-    result = pred(source=SOURCE, model=f"{MODEL}.pt")
-    assert len(result), "predictor test failed"
+    pred = detect.DetectionPredictor(overrides={'imgsz': [64, 64]})
+    result = pred(source=SOURCE, model=f'{MODEL}.pt')
+    assert len(result), 'predictor test failed'

-    overrides["resume"] = trainer.last
+    overrides['resume'] = trainer.last
     trainer = detect.DetectionTrainer(overrides=overrides)
     try:
         trainer.train()
     except Exception as e:
-        print(f"Expected exception caught: {e}")
+        print(f'Expected exception caught: {e}')
         return

-    Exception("Resume test failed!")
+    Exception('Resume test failed!')

 def test_segment():
-    overrides = {"data": "coco8-seg.yaml", "model": CFG_SEG, "imgsz": 32, "epochs": 1, "save": False}
-    CFG.data = "coco8-seg.yaml"
+    overrides = {'data': 'coco8-seg.yaml', 'model': CFG_SEG, 'imgsz': 32, 'epochs': 1, 'save': False}
+    CFG.data = 'coco8-seg.yaml'
     CFG.v5loader = False

     # YOLO(CFG_SEG).train(**overrides)  # works
@@ -57,25 +57,25 @@ def test_segment():
     val(model=trainer.best)  # validate best.pt

     # Predictor
-    pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
-    result = pred(source=SOURCE, model=f"{MODEL}-seg.pt")
-    assert len(result) == 2, "predictor test failed"
+    pred = segment.SegmentationPredictor(overrides={'imgsz': [64, 64]})
+    result = pred(source=SOURCE, model=f'{MODEL}-seg.pt')
+    assert len(result) == 2, 'predictor test failed'

     # Test resume
-    overrides["resume"] = trainer.last
+    overrides['resume'] = trainer.last
     trainer = segment.SegmentationTrainer(overrides=overrides)
     try:
         trainer.train()
     except Exception as e:
-        print(f"Expected exception caught: {e}")
+        print(f'Expected exception caught: {e}')
         return

-    Exception("Resume test failed!")
+    Exception('Resume test failed!')

 def test_classify():
-    overrides = {"data": "mnist160", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "batch": 64, "save": False}
-    CFG.data = "mnist160"
+    overrides = {'data': 'mnist160', 'model': 'yolov8n-cls.yaml', 'imgsz': 32, 'epochs': 1, 'batch': 64, 'save': False}
+    CFG.data = 'mnist160'
     CFG.imgsz = 32
     CFG.batch = 64

     # YOLO(CFG_SEG).train(**overrides)  # works
@@ -89,6 +89,6 @@ def test_classify():
     val(model=trainer.best)

     # Predictor
-    pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
+    pred = classify.ClassificationPredictor(overrides={'imgsz': [64, 64]})
     result = pred(source=SOURCE, model=trainer.best)
-    assert len(result) == 2, "predictor test failed"
+    assert len(result) == 2, 'predictor test failed'

@@ -37,24 +37,24 @@ def test_model_fuse():
 def test_predict_dir():
     model = YOLO(MODEL)
-    model(source=ROOT / "assets")
+    model(source=ROOT / 'assets')

 def test_predict_img():
     model = YOLO(MODEL)
     img = Image.open(str(SOURCE))
     output = model(source=img, save=True, verbose=True)  # PIL
-    assert len(output) == 1, "predict test failed"
+    assert len(output) == 1, 'predict test failed'
     img = cv2.imread(str(SOURCE))
     output = model(source=img, save=True, save_txt=True)  # ndarray
-    assert len(output) == 1, "predict test failed"
+    assert len(output) == 1, 'predict test failed'
     output = model(source=[img, img], save=True, save_txt=True)  # batch
-    assert len(output) == 2, "predict test failed"
+    assert len(output) == 2, 'predict test failed'
     output = model(source=[img, img], save=True, stream=True)  # stream
-    assert len(list(output)) == 2, "predict test failed"
+    assert len(list(output)) == 2, 'predict test failed'
     tens = torch.zeros(320, 640, 3)
     output = model(tens.numpy())
-    assert len(output) == 1, "predict test failed"
+    assert len(output) == 1, 'predict test failed'
     # test multiple source
     imgs = [
         SOURCE,  # filename
@@ -64,23 +64,23 @@ def test_predict_img():
         Image.open(SOURCE),  # PIL
         np.zeros((320, 640, 3))]  # numpy
     output = model(imgs)
-    assert len(output) == 6, "predict test failed!"
+    assert len(output) == 6, 'predict test failed!'

 def test_val():
     model = YOLO(MODEL)
-    model.val(data="coco8.yaml", imgsz=32)
+    model.val(data='coco8.yaml', imgsz=32)

 def test_train_scratch():
     model = YOLO(CFG)
-    model.train(data="coco8.yaml", epochs=1, imgsz=32)
+    model.train(data='coco8.yaml', epochs=1, imgsz=32)
     model(SOURCE)

 def test_train_pretrained():
     model = YOLO(MODEL)
-    model.train(data="coco8.yaml", epochs=1, imgsz=32)
+    model.train(data='coco8.yaml', epochs=1, imgsz=32)
     model(SOURCE)
@@ -139,10 +139,10 @@ def test_all_model_yamls():
 def test_workflow():
     model = YOLO(MODEL)
-    model.train(data="coco8.yaml", epochs=1, imgsz=32)
+    model.train(data='coco8.yaml', epochs=1, imgsz=32)
     model.val()
     model.predict(SOURCE)
-    model.export(format="onnx")  # export a model to ONNX format
+    model.export(format='onnx')  # export a model to ONNX format

 def test_predict_callback_and_setup():
@@ -154,8 +154,8 @@ def test_predict_callback_and_setup():
         bs = [predictor.dataset.bs for _ in range(len(path))]
         predictor.results = zip(predictor.results, im0s, bs)

-    model = YOLO("yolov8n.pt")
-    model.add_callback("on_predict_batch_end", on_predict_batch_end)
+    model = YOLO('yolov8n.pt')
+    model.add_callback('on_predict_batch_end', on_predict_batch_end)

     dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
     bs = dataset.bs  # noqa access predictor properties
@@ -168,8 +168,8 @@ def test_predict_callback_and_setup():
 def test_result():
-    model = YOLO("yolov8n-seg.pt")
-    img = str(ROOT / "assets/bus.jpg")
+    model = YOLO('yolov8n-seg.pt')
+    img = str(ROOT / 'assets/bus.jpg')
     res = model([img, img])
     res[0].numpy()
     res[0].cpu().numpy()

@@ -1,8 +1,8 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = "8.0.40"
+__version__ = '8.0.40'

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils.checks import check_yolo as checks

-__all__ = ["__version__", "YOLO", "checks"]  # allow simpler import
+__all__ = ['__version__', 'YOLO', 'checks']  # allow simpler import

@@ -10,10 +10,10 @@ from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import LOGGER, PREFIX, emojis

 # Define all export formats
-EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
+EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']

-def start(key=""):
+def start(key=''):
     """
     Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
     """
@@ -34,7 +34,7 @@ def start(key=""):
             session.register_callbacks(trainer)
         trainer.train(**session.train_args)
     except Exception as e:
-        LOGGER.warning(f"{PREFIX}{e}")
+        LOGGER.warning(f'{PREFIX}{e}')

 def request_api_key(auth, max_attempts=3):
@@ -43,56 +43,56 @@ def request_api_key(auth, max_attempts=3):
     """
     import getpass
     for attempts in range(max_attempts):
-        LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
-        input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
+        LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
+        input_key = getpass.getpass('Enter your Ultralytics HUB API key:\n')
         auth.api_key, model_id = split_key(input_key)
         if auth.authenticate():
-            LOGGER.info(f"{PREFIX}Authenticated ✅")
+            LOGGER.info(f'{PREFIX}Authenticated ✅')
             return model_id
-        LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
-    raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
+        LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n')
+    raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))

-def reset_model(key=""):
+def reset_model(key=''):
     # Reset a trained model to an untrained state
     api_key, model_id = split_key(key)
-    r = requests.post("https://api.ultralytics.com/model-reset", json={"apiKey": api_key, "modelId": model_id})
+    r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id})

     if r.status_code == 200:
-        LOGGER.info(f"{PREFIX}model reset successfully")
+        LOGGER.info(f'{PREFIX}model reset successfully')
         return
-    LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")
+    LOGGER.warning(f'{PREFIX}model reset failure {r.status_code} {r.reason}')

-def export_model(key="", format="torchscript"):
+def export_model(key='', format='torchscript'):
     # Export a model to all formats
     assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
     api_key, model_id = split_key(key)
-    r = requests.post("https://api.ultralytics.com/export",
+    r = requests.post('https://api.ultralytics.com/export',
                       json={
-                          "apiKey": api_key,
-                          "modelId": model_id,
-                          "format": format})
-    assert (r.status_code == 200), f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
-    LOGGER.info(f"{PREFIX}{format} export started ✅")
+                          'apiKey': api_key,
+                          'modelId': model_id,
+                          'format': format})
+    assert (r.status_code == 200), f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
+    LOGGER.info(f'{PREFIX}{format} export started ✅')

-def get_export(key="", format="torchscript"):
+def get_export(key='', format='torchscript'):
     # Get an exported model dictionary with download URL
     assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
     api_key, model_id = split_key(key)
-    r = requests.post("https://api.ultralytics.com/get-export",
+    r = requests.post('https://api.ultralytics.com/get-export',
                       json={
-                          "apiKey": api_key,
-                          "modelId": model_id,
-                          "format": format})
-    assert (r.status_code == 200), f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
+                          'apiKey': api_key,
+                          'modelId': model_id,
+                          'format': format})
+    assert (r.status_code == 200), f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
     return r.json()

 # temp. For checking
-if __name__ == "__main__":
+if __name__ == '__main__':
     start()
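A hedged usage sketch for the HUB helpers above. The composite key format is an assumption based on `split_key()`, which separates an API key from a model id on `'_'`:

```python
from ultralytics import hub

# 'API_KEY_MODELID' is a placeholder: split_key() would yield ('API_KEY', 'MODELID')
hub.export_model(key='API_KEY_MODELID', format='onnx')           # queue an ONNX export
exported = hub.get_export(key='API_KEY_MODELID', format='onnx')  # dict with download URL
```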

@@ -5,7 +5,7 @@ import requests
 from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
 from ultralytics.yolo.utils import is_colab

-API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys"
+API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys'

 class Auth:
@@ -18,7 +18,7 @@ class Auth:
     @staticmethod
     def _clean_api_key(key: str) -> str:
         """Strip model from key if present"""
-        separator = "_"
+        separator = '_'
         return key.split(separator)[0] if separator in key else key

     def authenticate(self) -> bool:
@@ -26,11 +26,11 @@ class Auth:
         try:
             header = self.get_auth_header()
             if header:
-                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
+                r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
                 if not r.json().get('success', False):
-                    raise ConnectionError("Unable to authenticate.")
+                    raise ConnectionError('Unable to authenticate.')
                 return True
-            raise ConnectionError("User has not authenticated locally.")
+            raise ConnectionError('User has not authenticated locally.')
         except ConnectionError:
             self.id_token = self.api_key = False  # reset invalid
             return False
@@ -43,21 +43,21 @@ class Auth:
         if not is_colab():
             return False  # Currently only works with Colab
         try:
-            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
-            if authn.get("success", False):
-                self.id_token = authn.get("data", {}).get("idToken", None)
+            authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
+            if authn.get('success', False):
+                self.id_token = authn.get('data', {}).get('idToken', None)
                 self.authenticate()
                 return True
-            raise ConnectionError("Unable to fetch browser authentication details.")
+            raise ConnectionError('Unable to fetch browser authentication details.')
         except ConnectionError:
             self.id_token = False  # reset invalid
             return False

     def get_auth_header(self):
         if self.id_token:
-            return {"authorization": f"Bearer {self.id_token}"}
+            return {'authorization': f'Bearer {self.id_token}'}
         elif self.api_key:
-            return {"x-api-key": self.api_key}
+            return {'x-api-key': self.api_key}
         else:
             return None

@@ -11,7 +11,7 @@ from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_
 from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded
 from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params

-AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
+AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
 session = None
@@ -20,9 +20,9 @@ class HubTrainingSession:
     def __init__(self, model_id, auth):
         self.agent_id = None  # identifies which instance is communicating with server
         self.model_id = model_id
-        self.api_url = f"{HUB_API_ROOT}/v1/models/{model_id}"
+        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
         self.auth_header = auth.get_auth_header()
-        self._rate_limits = {"metrics": 3.0, "ckpt": 900.0, "heartbeat": 300.0}  # rate limits (seconds)
+        self._rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
         self._timers = {}  # rate limit timers (seconds)
         self._metrics_queue = {}  # metrics queue
         self.model = self._get_model()
@@ -40,7 +40,7 @@ class HubTrainingSession:
         passed by signal.
         """
         if self.alive is True:
-            LOGGER.info(f"{PREFIX}Kill signal received! ❌")
+            LOGGER.info(f'{PREFIX}Kill signal received! ❌')
             self._stop_heartbeat()
             sys.exit(signum)
@@ -49,23 +49,23 @@ class HubTrainingSession:
         self.alive = False

     def upload_metrics(self):
-        payload = {"metrics": self._metrics_queue.copy(), "type": "metrics"}
-        smart_request(f"{self.api_url}", json=payload, headers=self.auth_header, code=2)
+        payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
+        smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)

     def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
         # Upload a model to HUB
         file = None
         if Path(weights).is_file():
-            with open(weights, "rb") as f:
+            with open(weights, 'rb') as f:
                 file = f.read()
         if final:
             smart_request(
-                f"{self.api_url}/upload",
+                f'{self.api_url}/upload',
                 data={
-                    "epoch": epoch,
-                    "type": "final",
-                    "map": map},
-                files={"best.pt": file},
+                    'epoch': epoch,
+                    'type': 'final',
+                    'map': map},
+                files={'best.pt': file},
                 headers=self.auth_header,
                 retry=10,
                 timeout=3600,
@@ -73,66 +73,66 @@ class HubTrainingSession:
             )
         else:
             smart_request(
-                f"{self.api_url}/upload",
+                f'{self.api_url}/upload',
                 data={
-                    "epoch": epoch,
-                    "type": "epoch",
-                    "isBest": bool(is_best)},
+                    'epoch': epoch,
+                    'type': 'epoch',
+                    'isBest': bool(is_best)},
                 headers=self.auth_header,
-                files={"last.pt": file},
+                files={'last.pt': file},
                 code=3,
             )

     def _get_model(self):
         # Returns model from database by id
-        api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
+        api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
         headers = self.auth_header

         try:
-            response = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
-            data = response.json().get("data", None)
+            response = smart_request(api_url, method='get', headers=headers, thread=False, code=0)
+            data = response.json().get('data', None)

-            if data.get("status", None) == "trained":
+            if data.get('status', None) == 'trained':
                 raise ValueError(
-                    emojis(f"Model is already trained and uploaded to "
-                           f"https://hub.ultralytics.com/models/{self.model_id} 🚀"))
+                    emojis(f'Model is already trained and uploaded to '
+                           f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))

-            if not data.get("data", None):
-                raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix
-            self.model_id = data["id"]
+            if not data.get('data', None):
+                raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
+            self.model_id = data['id']

             # TODO: restore when server keys when dataset URL and GPU train is working
             self.train_args = {
-                "batch": data["batch_size"],
-                "epochs": data["epochs"],
-                "imgsz": data["imgsz"],
-                "patience": data["patience"],
-                "device": data["device"],
-                "cache": data["cache"],
-                "data": data["data"]}
+                'batch': data['batch_size'],
+                'epochs': data['epochs'],
+                'imgsz': data['imgsz'],
+                'patience': data['patience'],
+                'device': data['device'],
+                'cache': data['cache'],
+                'data': data['data']}

-            self.input_file = data.get("cfg", data["weights"])
+            self.input_file = data.get('cfg', data['weights'])

             # hack for yolov5 cfg adds u
-            if "cfg" in data and "yolov5" in data["cfg"]:
-                self.input_file = data["cfg"].replace(".yaml", "u.yaml")
+            if 'cfg' in data and 'yolov5' in data['cfg']:
+                self.input_file = data['cfg'].replace('.yaml', 'u.yaml')

             return data
         except requests.exceptions.ConnectionError as e:
-            raise ConnectionRefusedError("ERROR: The HUB server is not online. Please try again later.") from e
+            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
         except Exception:
             raise

     def check_disk_space(self):
-        if not check_dataset_disk_space(self.model["data"]):
-            raise MemoryError("Not enough disk space")
+        if not check_dataset_disk_space(self.model['data']):
+            raise MemoryError('Not enough disk space')

     def register_callbacks(self, trainer):
-        trainer.add_callback("on_pretrain_routine_end", self.on_pretrain_routine_end)
-        trainer.add_callback("on_fit_epoch_end", self.on_fit_epoch_end)
-        trainer.add_callback("on_model_save", self.on_model_save)
-        trainer.add_callback("on_train_end", self.on_train_end)
+        trainer.add_callback('on_pretrain_routine_end', self.on_pretrain_routine_end)
+        trainer.add_callback('on_fit_epoch_end', self.on_fit_epoch_end)
+        trainer.add_callback('on_model_save', self.on_model_save)
+        trainer.add_callback('on_train_end', self.on_train_end)

     def on_pretrain_routine_end(self, trainer):
         """
@@ -140,57 +140,57 @@ class HubTrainingSession:
         This method does not use trainer. It is passed to all callbacks by default.
         """
         # Start timer for upload rate limit
-        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
-        self._timers = {"metrics": time(), "ckpt": time()}  # start timer on self.rate_limit
+        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
+        self._timers = {'metrics': time(), 'ckpt': time()}  # start timer on self.rate_limit

     def on_fit_epoch_end(self, trainer):
         # Upload metrics after val end
-        all_plots = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
+        all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
         if trainer.epoch == 0:
             model_info = {
-                "model/parameters": get_num_params(trainer.model),
-                "model/GFLOPs": round(get_flops(trainer.model), 3),
-                "model/speed(ms)": round(trainer.validator.speed[1], 3)}
+                'model/parameters': get_num_params(trainer.model),
+                'model/GFLOPs': round(get_flops(trainer.model), 3),
+                'model/speed(ms)': round(trainer.validator.speed[1], 3)}
             all_plots = {**all_plots, **model_info}
         self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
-        if time() - self._timers["metrics"] > self._rate_limits["metrics"]:
+        if time() - self._timers['metrics'] > self._rate_limits['metrics']:
             self.upload_metrics()
-            self._timers["metrics"] = time()  # reset timer
+            self._timers['metrics'] = time()  # reset timer
             self._metrics_queue = {}  # reset queue

     def on_model_save(self, trainer):
         # Upload checkpoints with rate limiting
         is_best = trainer.best_fitness == trainer.fitness
-        if time() - self._timers["ckpt"] > self._rate_limits["ckpt"]:
-            LOGGER.info(f"{PREFIX}Uploading checkpoint {self.model_id}")
+        if time() - self._timers['ckpt'] > self._rate_limits['ckpt']:
+            LOGGER.info(f'{PREFIX}Uploading checkpoint {self.model_id}')
             self._upload_model(trainer.epoch, trainer.last, is_best)
-            self._timers["ckpt"] = time()  # reset timer
+            self._timers['ckpt'] = time()  # reset timer

     def on_train_end(self, trainer):
         # Upload final model and metrics with exponential standoff
-        LOGGER.info(f"{PREFIX}Training completed successfully ✅")
-        LOGGER.info(f"{PREFIX}Uploading final {self.model_id}")
+        LOGGER.info(f'{PREFIX}Training completed successfully ✅')
+        LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')

         # hack for fetching mAP
-        mAP = trainer.metrics.get("metrics/mAP50-95(B)", 0)
+        mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
         self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True)  # results[3] is mAP0.5:0.95
         self.alive = False  # stop heartbeats
-        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
+        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')

     def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
         # Upload a model to HUB
         file = None
         if Path(weights).is_file():
-            with open(weights, "rb") as f:
+            with open(weights, 'rb') as f:
                 file = f.read()
-        file_param = {"best.pt" if final else "last.pt": file}
-        endpoint = f"{self.api_url}/upload"
-        data = {"epoch": epoch}
+        file_param = {'best.pt' if final else 'last.pt': file}
+        endpoint = f'{self.api_url}/upload'
+        data = {'epoch': epoch}
         if final:
-            data.update({"type": "final", "map": map})
+            data.update({'type': 'final', 'map': map})
         else:
-            data.update({"type": "epoch", "isBest": bool(is_best)})
+            data.update({'type': 'epoch', 'isBest': bool(is_best)})

         smart_request(
             endpoint,
@@ -207,14 +207,14 @@ class HubTrainingSession:
         self.alive = True
         while self.alive:
             r = smart_request(
-                f"{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}",
+                f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
                 json={
-                    "agent": AGENT_NAME,
-                    "agentId": self.agent_id},
+                    'agent': AGENT_NAME,
+                    'agentId': self.agent_id},
                 headers=self.auth_header,
                 retry=0,
                 code=5,
                 thread=False,
             )
-            self.agent_id = r.json().get("data", {}).get("agentId", None)
-            sleep(self._rate_limits["heartbeat"])
+            self.agent_id = r.json().get('data', {}).get('agentId', None)
+            sleep(self._rate_limits['heartbeat'])
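The `_timers`/`_rate_limits` pairing above is a simple client-side throttle; a minimal standalone sketch of the pattern (names are assumed, the print stands in for `smart_request`):

```python
from time import time

rate_limits = {'metrics': 3.0}  # minimum seconds between uploads
timers = {'metrics': time()}    # start timer
queue = {}                      # queued payloads

def maybe_upload(epoch, payload):
    queue[epoch] = payload
    if time() - timers['metrics'] > rate_limits['metrics']:
        print(f'uploading {len(queue)} queued metrics')  # stand-in for the real upload call
        timers['metrics'] = time()  # reset timer
        queue.clear()               # reset queue
```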

@@ -18,14 +18,14 @@ from ultralytics.yolo.utils.checks import check_online
 PREFIX = colorstr('Ultralytics: ')
 HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
-HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
+HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')

 def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
     # Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
     gib = 1 << 30  # bytes per GiB
     data = int(requests.head(url).headers['Content-Length']) / gib  # dataset size (GB)
-    total, used, free = (x / gib for x in shutil.disk_usage("/"))  # bytes
+    total, used, free = (x / gib for x in shutil.disk_usage('/'))  # bytes
     LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
     if data * sf < free:
         return True  # sufficient space
@@ -57,7 +57,7 @@ def request_with_credentials(url: str) -> any:
             });
         });
         """ % url))
-    return output.eval_js("_hub_tmp")
+    return output.eval_js('_hub_tmp')

 # Deprecated TODO: eliminate this function?
@@ -84,7 +84,7 @@ def split_key(key=''):
     return api_key, model_id

-def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post", verbose=True, **kwargs):
+def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post', verbose=True, **kwargs):
     """
     Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
@@ -128,7 +128,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
                     m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
                         f"Please retry after {h['Retry-After']}s."
                 if verbose:
-                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
+                    LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
                 if r.status_code not in retry_codes:
                     return r
             time.sleep(2 ** i)  # exponential standoff
@@ -149,17 +149,17 @@ class Traces:
         self.rate_limit = 3.0  # rate limit (seconds)
         self.t = 0.0  # rate limit timer (seconds)
         self.metadata = {
-            "sys_argv_name": Path(sys.argv[0]).name,
-            "install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
-            "python": platform.python_version(),
-            "release": __version__,
-            "environment": ENVIRONMENT}
+            'sys_argv_name': Path(sys.argv[0]).name,
+            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
+            'python': platform.python_version(),
+            'release': __version__,
+            'environment': ENVIRONMENT}
         self.enabled = SETTINGS['sync'] and \
                        RANK in {-1, 0} and \
                        check_online() and \
                        not is_pytest_running() and \
                        not is_github_actions_ci() and \
-                       (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
+                       (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')

     def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
         """

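`smart_request()` retries with `time.sleep(2 ** i)`, i.e. exponential backoff ("standoff" in the source comment); a stripped-down sketch of that retry loop (function name assumed):

```python
import time

def retry_with_backoff(func, retry=3):
    # try once, then wait 1s, 2s, 4s, ... between further attempts
    for i in range(retry + 1):
        try:
            return func()
        except ConnectionError:
            if i == retry:
                raise
            time.sleep(2 ** i)
```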
@@ -41,4 +41,4 @@ head:
    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
    [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
-  ]
\ No newline at end of file
+  ]

@@ -41,4 +41,4 @@ head:
    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
    [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
-  ]
\ No newline at end of file
+  ]

@@ -41,4 +41,4 @@ head:
    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
    [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
-  ]
\ No newline at end of file
+  ]

@@ -42,4 +42,4 @@ head:
    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
    [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
-  ]
\ No newline at end of file
+  ]

@@ -41,4 +41,4 @@ head:
    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
    [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
-  ]
\ No newline at end of file
+  ]

@@ -127,11 +127,11 @@ class AutoBackend(nn.Module):
             w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
             network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
             if network.get_parameters()[0].get_layout().empty:
-                network.get_parameters()[0].set_layout(Layout("NCHW"))
+                network.get_parameters()[0].set_layout(Layout('NCHW'))
             batch_dim = get_batch(network)
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
+            executable_network = ie.compile_model(network, device_name='CPU')  # device_name="MYRIAD" for Intel NCS2
         elif engine:  # TensorRT
             LOGGER.info(f'Loading {w} for TensorRT inference...')
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
@@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
             import tensorflow as tf

             def wrap_frozen_graph(gd, inputs, outputs):
-                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
                 ge = x.graph.as_graph_element
                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
@@ -198,7 +198,7 @@ class AutoBackend(nn.Module):
             gd = tf.Graph().as_graph_def()  # TF GraphDef
             with open(w, 'rb') as f:
                 gd.ParseFromString(f.read())
-            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
@@ -220,9 +220,9 @@ class AutoBackend(nn.Module):
             output_details = interpreter.get_output_details()  # outputs
             # load metadata
             with contextlib.suppress(zipfile.BadZipFile):
-                with zipfile.ZipFile(w, "r") as model:
+                with zipfile.ZipFile(w, 'r') as model:
                     meta_file = model.namelist()[0]
-                    meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
+                    meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
                     stride, names = int(meta['stride']), meta['names']
         elif tfjs:  # TF.js
             raise NotImplementedError('YOLOv8 TF.js inference is not supported')
@@ -251,8 +251,8 @@ class AutoBackend(nn.Module):
         else:
             from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
             raise TypeError(f"model='{w}' is not a supported model format. "
-                            "See https://docs.ultralytics.com/tasks/detection/#export for help."
-                            f"\n\n{EXPORT_FORMATS_TABLE}")
+                            'See https://docs.ultralytics.com/tasks/detection/#export for help.'
+                            f'\n\n{EXPORT_FORMATS_TABLE}')

         # Load external metadata YAML
         if xml or saved_model or paddle:
@@ -410,5 +410,5 @@ class AutoBackend(nn.Module):
         url = urlparse(p)  # if url may be Triton inference server
         types = [s in Path(p).name for s in sf]
         types[8] &= not types[9]  # tflite &= not edgetpu
-        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
+        triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
         return types + [triton]
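The last `AutoBackend` hunk touches the Triton detection: a source is treated as a Triton Inference Server endpoint when it matches no known export suffix but parses as an http/grpc URL. A quick illustration of the URL test (the endpoint is made up):

```python
from urllib.parse import urlparse

url = urlparse('http://localhost:8000/yolov8n')
print(all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc]))  # True
print(urlparse('yolov8n.onnx').netloc)  # '' -> a plain file path is not Triton
```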

@@ -99,7 +99,7 @@ class AutoShape(nn.Module):
                 shape1.append([y * g for y in s])
                 ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
             shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
-            x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims]  # pad
+            x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims]  # pad
             x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

@@ -160,7 +160,7 @@ class BaseModel(nn.Module):
             weights (str): The weights to load into the model.
         """
         # Force all tasks to implement this function
-        raise NotImplementedError("This function needs to be implemented by derived classes!")
+        raise NotImplementedError('This function needs to be implemented by derived classes!')

 class DetectionModel(BaseModel):
@@ -249,7 +249,7 @@ class SegmentationModel(DetectionModel):
         super().__init__(cfg, ch, nc, verbose)

     def _forward_augment(self, x):
-        raise NotImplementedError("WARNING ⚠️ SegmentationModel has not supported augment inference yet!")
+        raise NotImplementedError('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')

 class ClassificationModel(BaseModel):
@@ -292,7 +292,7 @@ class ClassificationModel(BaseModel):
         self.info()

     def load(self, weights):
-        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
+        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
         csd = model.float().state_dict()
         csd = intersect_dicts(csd, self.state_dict())  # intersect
         self.load_state_dict(csd, strict=False)  # load
@@ -341,10 +341,10 @@ def torch_safe_load(weight):
         return torch.load(file, map_location='cpu')  # load
     except ModuleNotFoundError as e:
         if e.name == 'omegaconf':  # e.name is missing module name
-            LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
-                           f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
-                           f"\nRecommend fixes are to train a new model using updated ultralytics package or to "
-                           f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
+            LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
+                           f'\nAutoInstall will run now for {e.name} but this feature will be removed in the future.'
+                           f'\nRecommend fixes are to train a new model using updated ultralytics package or to '
+                           f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
         if e.name != 'models':
             check_requirements(e.name)  # install missing module
         return torch.load(file, map_location='cpu')  # load
@@ -489,13 +489,13 @@ def guess_model_task(model):
     def cfg2task(cfg):
         # Guess from YAML dictionary
-        m = cfg["head"][-1][-2].lower()  # output module name
-        if m in ["classify", "classifier", "cls", "fc"]:
-            return "classify"
-        if m in ["detect"]:
-            return "detect"
-        if m in ["segment"]:
-            return "segment"
+        m = cfg['head'][-1][-2].lower()  # output module name
+        if m in ['classify', 'classifier', 'cls', 'fc']:
+            return 'classify'
+        if m in ['detect']:
+            return 'detect'
+        if m in ['segment']:
+            return 'segment'

     # Guess from model cfg
     if isinstance(model, dict):
@@ -513,22 +513,22 @@ def guess_model_task(model):
         for m in model.modules():
             if isinstance(m, Detect):
-                return "detect"
+                return 'detect'
             elif isinstance(m, Segment):
-                return "segment"
+                return 'segment'
             elif isinstance(m, Classify):
-                return "classify"
+                return 'classify'

     # Guess from model filename
     if isinstance(model, (str, Path)):
         model = Path(model).stem
         if '-seg' in model:
-            return "segment"
+            return 'segment'
         elif '-cls' in model:
-            return "classify"
+            return 'classify'
         else:
-            return "detect"
+            return 'detect'

     # Unable to determine task from model
-    raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
+    raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
                       "i.e. 'task=detect', 'task=segment' or 'task=classify'.")

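The filename branch of `guess_model_task` reduces to a small pure function; a standalone sketch for illustration:

```python
from pathlib import Path

def task_from_filename(model):
    # mirrors the '-seg' / '-cls' / default branches in the diff above
    stem = Path(model).stem
    if '-seg' in stem:
        return 'segment'
    elif '-cls' in stem:
        return 'classify'
    return 'detect'

print(task_from_filename('yolov8n-seg.pt'))  # segment
print(task_from_filename('yolov8n.pt'))      # detect
```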
@@ -4,14 +4,14 @@ from ultralytics.tracker import BOTSORT, BYTETracker
 from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
 from ultralytics.yolo.utils.checks import check_requirements, check_yaml

-TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
+TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
 check_requirements('lap')  # for linear_assignment

 def on_predict_start(predictor):
     tracker = check_yaml(predictor.args.tracker)
     cfg = IterableSimpleNamespace(**yaml_load(tracker))
-    assert cfg.tracker_type in ["bytetrack", "botsort"], \
+    assert cfg.tracker_type in ['bytetrack', 'botsort'], \
         f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
     trackers = []
     for _ in range(predictor.dataset.bs):
@@ -38,5 +38,5 @@ def on_predict_postprocess_end(predictor):
 def register_tracker(model):
-    model.add_callback("on_predict_start", on_predict_start)
-    model.add_callback("on_predict_postprocess_end", on_predict_postprocess_end)
+    model.add_callback('on_predict_start', on_predict_start)
+    model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)

@@ -153,7 +153,7 @@ class STrack(BaseTrack):
         return ret

     def __repr__(self):
-        return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
+        return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'

 class BYTETracker:
@@ -206,7 +206,7 @@ class BYTETracker:
         strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
         # Predict the current location with KF
         self.multi_predict(strack_pool)
-        if hasattr(self, "gmc"):
+        if hasattr(self, 'gmc'):
             warp = self.gmc.apply(img, dets)
             STrack.multi_gmc(strack_pool, warp)
             STrack.multi_gmc(unconfirmed, warp)

@@ -50,14 +50,14 @@ class GMC:
                 seqName = seqName[:-6]
             elif '-DPM' in seqName or '-SDP' in seqName:
                 seqName = seqName[:-4]
-            self.gmcFile = open(f"{filePath}/GMC-{seqName}.txt")
+            self.gmcFile = open(f'{filePath}/GMC-{seqName}.txt')

             if self.gmcFile is None:
-                raise ValueError(f"Error: Unable to open GMC file in directory:{filePath}")
+                raise ValueError(f'Error: Unable to open GMC file in directory:{filePath}')
         elif self.method in ['none', 'None']:
             self.method = 'none'
         else:
-            raise ValueError(f"Error: Unknown CMC method:{method}")
+            raise ValueError(f'Error: Unknown CMC method:{method}')

         self.prevFrame = None
         self.prevKeyPoints = None
@@ -302,7 +302,7 @@ class GMC:
     def applyFile(self, raw_frame, detections=None):
         line = self.gmcFile.readline()
-        tokens = line.split("\t")
+        tokens = line.split('\t')
         H = np.eye(2, 3, dtype=np.float_)
         H[0, 0] = float(tokens[1])
         H[0, 1] = float(tokens[2])

@@ -2,4 +2,4 @@
 from . import v8

-__all__ = ["v8"]
+__all__ = ['v8']

@@ -142,8 +142,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
         string = ''
         for x in mismatched:
             matches = get_close_matches(x, base)  # key list
-            matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
-            match_str = f"Similar arguments are i.e. {matches}." if matches else ''
+            matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
+            match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
             string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
         raise SyntaxError(string + CLI_HELP_MSG) from e
@@ -163,10 +163,10 @@ def merge_equals_args(args: List[str]) -> List[str]:
     new_args = []
     for i, arg in enumerate(args):
         if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
-            new_args[-1] += f"={args[i + 1]}"
+            new_args[-1] += f'={args[i + 1]}'
             del args[i + 1]
         elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val']
-            new_args.append(f"{arg}{args[i + 1]}")
+            new_args.append(f'{arg}{args[i + 1]}')
             del args[i + 1]
         elif arg.startswith('=') and i > 0:  # merge ['arg', '=val']
             new_args[-1] += arg
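`merge_equals_args` re-joins `arg=value` pairs that the shell split apart; a self-contained sketch of the three merge rules shown above, with a small demo:

```python
from typing import List

def merge_equals_args(args: List[str]) -> List[str]:
    new_args = []
    for i, arg in enumerate(args):
        if arg == '=' and 0 < i < len(args) - 1:  # ['arg', '=', 'val']
            new_args[-1] += f'={args[i + 1]}'
            del args[i + 1]
        elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # ['arg=', 'val']
            new_args.append(f'{arg}{args[i + 1]}')
            del args[i + 1]
        elif arg.startswith('=') and i > 0:  # ['arg', '=val']
            new_args[-1] += arg
        else:
            new_args.append(arg)
    return new_args

print(merge_equals_args(['imgsz', '=', '640']))  # ['imgsz=640']
print(merge_equals_args(['imgsz=', '640']))      # ['imgsz=640']
print(merge_equals_args(['imgsz', '=640']))      # ['imgsz=640']
```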
@@ -223,7 +223,7 @@ def entrypoint(debug=''):
             k, v = a.split('=', 1)  # split on first '=' sign
             assert v, f"missing '{k}' value"
             if k == 'cfg':  # custom.yaml passed
-                LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
+                LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
                 overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
             else:
                 if v.lower() == 'none':
@@ -237,7 +237,7 @@ def entrypoint(debug=''):
                     v = eval(v)
                 overrides[k] = v
             except (NameError, SyntaxError, ValueError, AssertionError) as e:
-                check_cfg_mismatch(full_args_dict, {a: ""}, e)
+                check_cfg_mismatch(full_args_dict, {a: ''}, e)
         elif a in tasks:
             overrides['task'] = a
@@ -252,7 +252,7 @@ def entrypoint(debug=''):
             raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
                               f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
         else:
-            check_cfg_mismatch(full_args_dict, {a: ""})
+            check_cfg_mismatch(full_args_dict, {a: ''})

     # Defaults
     task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
@@ -287,8 +287,8 @@ def entrypoint(debug=''):
         task = model.task
         overrides['task'] = task
     if mode in {'predict', 'track'} and 'source' not in overrides:
-        overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
-            else "https://ultralytics.com/images/bus.jpg"
+        overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
+            else 'https://ultralytics.com/images/bus.jpg'
         LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
     elif mode in ('train', 'val'):
         if 'data' not in overrides:
@@ -308,7 +308,7 @@ def entrypoint(debug=''):
 def copy_default_cfg():
     new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
     shutil.copy2(DEFAULT_CFG_PATH, new_file)
-    LOGGER.info(f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
+    LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
                 f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8")

@@ -6,11 +6,11 @@ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
 from .dataset_wrappers import MixAndRectDataset
 
 __all__ = [
-    "BaseDataset",
-    "ClassificationDataset",
-    "MixAndRectDataset",
-    "SemanticDataset",
-    "YOLODataset",
-    "build_classification_dataloader",
-    "build_dataloader",
-    "load_inference_source",]
+    'BaseDataset',
+    'ClassificationDataset',
+    'MixAndRectDataset',
+    'SemanticDataset',
+    'YOLODataset',
+    'build_classification_dataloader',
+    'build_dataloader',
+    'load_inference_source',]

@@ -55,11 +55,11 @@ class Compose:
         return self.transforms
 
     def __repr__(self):
-        format_string = f"{self.__class__.__name__}("
+        format_string = f'{self.__class__.__name__}('
         for t in self.transforms:
-            format_string += "\n"
-            format_string += f"    {t}"
-        format_string += "\n)"
+            format_string += '\n'
+            format_string += f'    {t}'
+        format_string += '\n)'
         return format_string

@@ -86,11 +86,11 @@ class BaseMixTransform:
         if self.pre_transform is not None:
             for i, data in enumerate(mix_labels):
                 mix_labels[i] = self.pre_transform(data)
-        labels["mix_labels"] = mix_labels
+        labels['mix_labels'] = mix_labels
 
         # Mosaic or MixUp
         labels = self._mix_transform(labels)
-        labels.pop("mix_labels", None)
+        labels.pop('mix_labels', None)
         return labels
 
     def _mix_transform(self, labels):

@@ -109,7 +109,7 @@ class Mosaic(BaseMixTransform):
     """
 
     def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
-        assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
+        assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
         super().__init__(dataset=dataset, p=p)
         self.dataset = dataset
         self.imgsz = imgsz

@@ -120,15 +120,15 @@ class Mosaic(BaseMixTransform):
     def _mix_transform(self, labels):
         mosaic_labels = []
-        assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
-        assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
+        assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
+        assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
         s = self.imgsz
         yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
         for i in range(4):
-            labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
+            labels_patch = (labels if i == 0 else labels['mix_labels'][i - 1]).copy()
             # Load image
-            img = labels_patch["img"]
-            h, w = labels_patch.pop("resized_shape")
+            img = labels_patch['img']
+            h, w = labels_patch.pop('resized_shape')
 
             # place img in img4
             if i == 0:  # top left

@@ -152,15 +152,15 @@ class Mosaic(BaseMixTransform):
             labels_patch = self._update_labels(labels_patch, padw, padh)
             mosaic_labels.append(labels_patch)
         final_labels = self._cat_labels(mosaic_labels)
-        final_labels["img"] = img4
+        final_labels['img'] = img4
         return final_labels
 
     def _update_labels(self, labels, padw, padh):
         """Update labels"""
-        nh, nw = labels["img"].shape[:2]
-        labels["instances"].convert_bbox(format="xyxy")
-        labels["instances"].denormalize(nw, nh)
-        labels["instances"].add_padding(padw, padh)
+        nh, nw = labels['img'].shape[:2]
+        labels['instances'].convert_bbox(format='xyxy')
+        labels['instances'].denormalize(nw, nh)
+        labels['instances'].add_padding(padw, padh)
         return labels
 
     def _cat_labels(self, mosaic_labels):

@@ -169,16 +169,16 @@ class Mosaic(BaseMixTransform):
         cls = []
         instances = []
         for labels in mosaic_labels:
-            cls.append(labels["cls"])
-            instances.append(labels["instances"])
+            cls.append(labels['cls'])
+            instances.append(labels['instances'])
         final_labels = {
-            "im_file": mosaic_labels[0]["im_file"],
-            "ori_shape": mosaic_labels[0]["ori_shape"],
-            "resized_shape": (self.imgsz * 2, self.imgsz * 2),
-            "cls": np.concatenate(cls, 0),
-            "instances": Instances.concatenate(instances, axis=0),
-            "mosaic_border": self.border}
-        final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
+            'im_file': mosaic_labels[0]['im_file'],
+            'ori_shape': mosaic_labels[0]['ori_shape'],
+            'resized_shape': (self.imgsz * 2, self.imgsz * 2),
+            'cls': np.concatenate(cls, 0),
+            'instances': Instances.concatenate(instances, axis=0),
+            'mosaic_border': self.border}
+        final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
         return final_labels
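
Note on the mosaic center sampling above: with a typical border of (-imgsz // 2, -imgsz // 2), uniform(-x, 2 * s + x) keeps the center inside the middle of the 2s x 2s canvas. A small numeric sketch (values assumed for illustration):

import random

s = 640                # imgsz
border = (-320, -320)  # typical mosaic border of -imgsz // 2
# uniform(-x, 2 * s + x) with x = -320 gives uniform(320, 960),
# i.e. the center stays at least s/2 away from every canvas edge
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in border)
assert 320 <= yc <= 960 and 320 <= xc <= 960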
@@ -193,10 +193,10 @@ class MixUp(BaseMixTransform):
     def _mix_transform(self, labels):
         # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
         r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
-        labels2 = labels["mix_labels"][0]
-        labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
-        labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
-        labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
+        labels2 = labels['mix_labels'][0]
+        labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
+        labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
+        labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
         return labels
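
Standalone, the pixel blend above is a convex combination of two images; Beta(32, 32) concentrates the ratio tightly around 0.5 so both images stay visible. A minimal sketch with random arrays standing in for real samples:

import numpy as np

rng = np.random.default_rng(0)
im1 = rng.integers(0, 256, (640, 640, 3), dtype=np.uint8)
im2 = rng.integers(0, 256, (640, 640, 3), dtype=np.uint8)

r = np.random.beta(32.0, 32.0)  # almost always in roughly [0.35, 0.65]
mixed = (im1 * r + im2 * (1 - r)).astype(np.uint8)
# Note: the label sets of both images are concatenated, not blended.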
@@ -338,18 +338,18 @@ class RandomPerspective:
         Args:
             labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
         """
-        if self.pre_transform and "mosaic_border" not in labels:
+        if self.pre_transform and 'mosaic_border' not in labels:
             labels = self.pre_transform(labels)
-        labels.pop("ratio_pad")  # do not need ratio pad
+        labels.pop('ratio_pad')  # do not need ratio pad
 
-        img = labels["img"]
-        cls = labels["cls"]
-        instances = labels.pop("instances")
+        img = labels['img']
+        cls = labels['cls']
+        instances = labels.pop('instances')
         # make sure the coord formats are right
-        instances.convert_bbox(format="xyxy")
+        instances.convert_bbox(format='xyxy')
         instances.denormalize(*img.shape[:2][::-1])
 
-        border = labels.pop("mosaic_border", self.border)
+        border = labels.pop('mosaic_border', self.border)
         self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2  # w, h
         # M is affine matrix
         # scale for func:`box_candidates`

@@ -365,7 +365,7 @@ class RandomPerspective:
         if keypoints is not None:
             keypoints = self.apply_keypoints(keypoints, M)
-        new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
+        new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
         # clip
         new_instances.clip(*self.size)

@@ -375,10 +375,10 @@ class RandomPerspective:
         i = self.box_candidates(box1=instances.bboxes.T,
                                 box2=new_instances.bboxes.T,
                                 area_thr=0.01 if len(segments) else 0.10)
-        labels["instances"] = new_instances[i]
-        labels["cls"] = cls[i]
-        labels["img"] = img
-        labels["resized_shape"] = img.shape[:2]
+        labels['instances'] = new_instances[i]
+        labels['cls'] = cls[i]
+        labels['img'] = img
+        labels['resized_shape'] = img.shape[:2]
         return labels
 
     def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
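
The candidate filter whose signature ends the hunk above keeps only boxes that survive the warp with sensible geometry. A sketch of the familiar YOLOv5-style criteria behind that signature (minimum side length, retained area fraction, aspect-ratio cap); treat it as a reference implementation rather than the exact body:

import numpy as np

def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
    # box1: boxes before augmentation (4, n); box2: boxes after (4, n), both xyxy
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)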
@@ -397,7 +397,7 @@ class RandomHSV:
         self.vgain = vgain
 
     def __call__(self, labels):
-        img = labels["img"]
+        img = labels['img']
         if self.hgain or self.sgain or self.vgain:
             r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains
             hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))

@@ -415,30 +415,30 @@ class RandomHSV:
 class RandomFlip:
 
-    def __init__(self, p=0.5, direction="horizontal") -> None:
-        assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
+    def __init__(self, p=0.5, direction='horizontal') -> None:
+        assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
         assert 0 <= p <= 1.0
         self.p = p
         self.direction = direction
 
     def __call__(self, labels):
-        img = labels["img"]
-        instances = labels.pop("instances")
-        instances.convert_bbox(format="xywh")
+        img = labels['img']
+        instances = labels.pop('instances')
+        instances.convert_bbox(format='xywh')
         h, w = img.shape[:2]
         h = 1 if instances.normalized else h
         w = 1 if instances.normalized else w
 
         # Flip up-down
-        if self.direction == "vertical" and random.random() < self.p:
+        if self.direction == 'vertical' and random.random() < self.p:
             img = np.flipud(img)
             instances.flipud(h)
-        if self.direction == "horizontal" and random.random() < self.p:
+        if self.direction == 'horizontal' and random.random() < self.p:
             img = np.fliplr(img)
             instances.fliplr(w)
-        labels["img"] = np.ascontiguousarray(img)
-        labels["instances"] = instances
+        labels['img'] = np.ascontiguousarray(img)
+        labels['instances'] = instances
         return labels
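
Because the instances are converted to xywh first, flipping a box reduces to mirroring its center coordinate about the image width (or 1 for normalized boxes). A tiny sketch of the horizontal case:

import numpy as np

boxes = np.array([[0.2, 0.3, 0.1, 0.1]])  # normalized xywh: center-x, center-y, w, h
w = 1  # normalized coordinates, so mirror about 1
boxes_flipped = boxes.copy()
boxes_flipped[:, 0] = w - boxes[:, 0]  # only center-x changes; w, h are unaffected
print(boxes_flipped)  # [[0.8 0.3 0.1 0.1]]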
@@ -455,9 +455,9 @@ class LetterBox:
     def __call__(self, labels=None, image=None):
         if labels is None:
             labels = {}
-        img = labels.get("img") if image is None else image
+        img = labels.get('img') if image is None else image
         shape = img.shape[:2]  # current shape [height, width]
-        new_shape = labels.pop("rect_shape", self.new_shape)
+        new_shape = labels.pop('rect_shape', self.new_shape)
         if isinstance(new_shape, int):
             new_shape = (new_shape, new_shape)

@@ -479,8 +479,8 @@ class LetterBox:
         dw /= 2  # divide padding into 2 sides
         dh /= 2
-        if labels.get("ratio_pad"):
-            labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh))  # for evaluation
+        if labels.get('ratio_pad'):
+            labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh))  # for evaluation
 
         if shape[::-1] != new_unpad:  # resize
             img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

@@ -491,18 +491,18 @@ class LetterBox:
         if len(labels):
             labels = self._update_labels(labels, ratio, dw, dh)
-            labels["img"] = img
-            labels["resized_shape"] = new_shape
+            labels['img'] = img
+            labels['resized_shape'] = new_shape
             return labels
         else:
             return img
 
     def _update_labels(self, labels, ratio, padw, padh):
         """Update labels"""
-        labels["instances"].convert_bbox(format="xyxy")
-        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
-        labels["instances"].scale(*ratio)
-        labels["instances"].add_padding(padw, padh)
+        labels['instances'].convert_bbox(format='xyxy')
+        labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
+        labels['instances'].scale(*ratio)
+        labels['instances'].add_padding(padw, padh)
         return labels
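
For intuition on the dw/dh halving above: letterboxing a 1280x720 frame to 640x640 scales by r = 0.5 to 640x360, then pads 140 px on top and bottom. A quick arithmetic sketch of that computation:

shape = (720, 1280)    # source (height, width)
new_shape = (640, 640)  # target
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])       # 0.5
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (640, 360) as (w, h)
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # 0, 280
dw, dh = dw / 2, dh / 2  # 0.0, 140.0 -> padding per side
print(r, new_unpad, (dw, dh))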
@@ -513,11 +513,11 @@ class CopyPaste:
     def __call__(self, labels):
         # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
-        im = labels["img"]
-        cls = labels["cls"]
+        im = labels['img']
+        cls = labels['cls']
         h, w = im.shape[:2]
-        instances = labels.pop("instances")
-        instances.convert_bbox(format="xyxy")
+        instances = labels.pop('instances')
+        instances.convert_bbox(format='xyxy')
         instances.denormalize(w, h)
         if self.p and len(instances.segments):
             n = len(instances)

@@ -540,9 +540,9 @@ class CopyPaste:
             i = cv2.flip(im_new, 1).astype(bool)
             im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug
 
-        labels["img"] = im
-        labels["cls"] = cls
-        labels["instances"] = instances
+        labels['img'] = im
+        labels['cls'] = cls
+        labels['instances'] = instances
         return labels

@@ -551,11 +551,11 @@ class Albumentations:
     def __init__(self, p=1.0):
         self.p = p
         self.transform = None
-        prefix = colorstr("albumentations: ")
+        prefix = colorstr('albumentations: ')
         try:
             import albumentations as A
 
-            check_version(A.__version__, "1.0.3", hard=True)  # version requirement
+            check_version(A.__version__, '1.0.3', hard=True)  # version requirement
 
             T = [
                 A.Blur(p=0.01),

@@ -565,28 +565,28 @@ class Albumentations:
                 A.RandomBrightnessContrast(p=0.0),
                 A.RandomGamma(p=0.0),
                 A.ImageCompression(quality_lower=75, p=0.0),]  # transforms
-            self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
+            self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
 
-            LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
+            LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
         except ImportError:  # package not installed, skip
             pass
         except Exception as e:
-            LOGGER.info(f"{prefix}{e}")
+            LOGGER.info(f'{prefix}{e}')
 
     def __call__(self, labels):
-        im = labels["img"]
-        cls = labels["cls"]
+        im = labels['img']
+        cls = labels['cls']
         if len(cls):
-            labels["instances"].convert_bbox("xywh")
-            labels["instances"].normalize(*im.shape[:2][::-1])
-            bboxes = labels["instances"].bboxes
+            labels['instances'].convert_bbox('xywh')
+            labels['instances'].normalize(*im.shape[:2][::-1])
+            bboxes = labels['instances'].bboxes
             # TODO: add supports of segments and keypoints
             if self.transform and random.random() < self.p:
                 new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
-                labels["img"] = new["image"]
-                labels["cls"] = np.array(new["class_labels"])
-                bboxes = np.array(new["bboxes"])
-            labels["instances"].update(bboxes=bboxes)
+                labels['img'] = new['image']
+                labels['cls'] = np.array(new['class_labels'])
+                bboxes = np.array(new['bboxes'])
+            labels['instances'].update(bboxes=bboxes)
         return labels

@@ -594,7 +594,7 @@ class Albumentations:
 class Format:
 
     def __init__(self,
-                 bbox_format="xywh",
+                 bbox_format='xywh',
                  normalize=True,
                  return_mask=False,
                  return_keypoint=False,

@@ -610,10 +610,10 @@ class Format:
         self.batch_idx = batch_idx  # keep the batch indexes
 
     def __call__(self, labels):
-        img = labels.pop("img")
+        img = labels.pop('img')
         h, w = img.shape[:2]
-        cls = labels.pop("cls")
-        instances = labels.pop("instances")
+        cls = labels.pop('cls')
+        instances = labels.pop('instances')
         instances.convert_bbox(format=self.bbox_format)
         instances.denormalize(w, h)
         nl = len(instances)

@@ -625,17 +625,17 @@ class Format:
             else:
                 masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
                                     img.shape[1] // self.mask_ratio)
-            labels["masks"] = masks
+            labels['masks'] = masks
         if self.normalize:
             instances.normalize(w, h)
-        labels["img"] = self._format_img(img)
-        labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
-        labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
+        labels['img'] = self._format_img(img)
+        labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
+        labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
         if self.return_keypoint:
-            labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
+            labels['keypoints'] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
         # then we can use collate_fn
         if self.batch_idx:
-            labels["batch_idx"] = torch.zeros(nl)
+            labels['batch_idx'] = torch.zeros(nl)
         return labels
 
     def _format_img(self, img):

@@ -676,15 +676,15 @@ def v8_transforms(dataset, imgsz, hyp):
         MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
         Albumentations(p=1.0),
         RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
-        RandomFlip(direction="vertical", p=hyp.flipud),
-        RandomFlip(direction="horizontal", p=hyp.fliplr),])  # transforms
+        RandomFlip(direction='vertical', p=hyp.flipud),
+        RandomFlip(direction='horizontal', p=hyp.fliplr),])  # transforms
 
 
 # Classification augmentations -----------------------------------------------------------------------------------------
 def classify_transforms(size=224):
     # Transforms to apply if albumentations not installed
     if not isinstance(size, int):
-        raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
+        raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
     # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
     return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
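
A minimal usage sketch of the classification pipeline above, assuming a BGR uint8 array as produced by cv2.imread (CenterCrop and ToTensor here are the repo's own numpy-based transforms, so no PIL conversion is needed; the file name is hypothetical):

import cv2

t = classify_transforms(size=224)
im = cv2.imread('bus.jpg')  # HWC, BGR, uint8
x = t(im)                   # CHW float tensor, center-cropped and ImageNet-normalized
print(x.shape)              # expected: torch.Size([3, 224, 224])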
@@ -701,17 +701,17 @@ def classify_albumentations(
         auto_aug=False,
 ):
     # YOLOv8 classification Albumentations (optional, only used if package is installed)
-    prefix = colorstr("albumentations: ")
+    prefix = colorstr('albumentations: ')
     try:
         import albumentations as A
         from albumentations.pytorch import ToTensorV2
 
-        check_version(A.__version__, "1.0.3", hard=True)  # version requirement
+        check_version(A.__version__, '1.0.3', hard=True)  # version requirement
         if augment:  # Resize and crop
             T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
             if auto_aug:
                 # TODO: implement AugMix, AutoAug & RandAug in albumentation
-                LOGGER.info(f"{prefix}auto augmentations are currently not supported")
+                LOGGER.info(f'{prefix}auto augmentations are currently not supported')
             else:
                 if hflip > 0:
                     T += [A.HorizontalFlip(p=hflip)]

@@ -723,13 +723,13 @@ def classify_albumentations(
         else:  # Use fixed crop for eval set (reproducibility)
             T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
         T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
-        LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
+        LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
         return A.Compose(T)
     except ImportError:  # package not installed, skip
         pass
     except Exception as e:
-        LOGGER.info(f"{prefix}{e}")
+        LOGGER.info(f'{prefix}{e}')
 
 
 class ClassifyLetterBox:

@@ -31,7 +31,7 @@ class BaseDataset(Dataset):
                  cache=False,
                  augment=True,
                  hyp=None,
-                 prefix="",
+                 prefix='',
                  rect=False,
                  batch_size=None,
                  stride=32,

@@ -63,7 +63,7 @@ class BaseDataset(Dataset):
         # cache stuff
         self.ims = [None] * self.ni
-        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
+        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
         if cache:
             self.cache_images(cache)

@@ -77,21 +77,21 @@ class BaseDataset(Dataset):
             for p in img_path if isinstance(img_path, list) else [img_path]:
                 p = Path(p)  # os-agnostic
                 if p.is_dir():  # dir
-                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
+                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                     # f = list(p.rglob('*.*'))  # pathlib
                 elif p.is_file():  # file
                     with open(p) as t:
                         t = t.read().strip().splitlines()
                         parent = str(p.parent) + os.sep
-                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
+                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                         # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                 else:
-                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
-            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
+                    raise FileNotFoundError(f'{self.prefix}{p} does not exist')
+            im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
-            assert im_files, f"{self.prefix}No images found"
+            assert im_files, f'{self.prefix}No images found'
         except Exception as e:
-            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
+            raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
         return im_files
 
     def update_labels(self, include_class: Optional[list]):

@@ -99,16 +99,16 @@ class BaseDataset(Dataset):
         include_class_array = np.array(include_class).reshape(1, -1)
         for i in range(len(self.labels)):
             if include_class:
-                cls = self.labels[i]["cls"]
-                bboxes = self.labels[i]["bboxes"]
-                segments = self.labels[i]["segments"]
+                cls = self.labels[i]['cls']
+                bboxes = self.labels[i]['bboxes']
+                segments = self.labels[i]['segments']
                 j = (cls == include_class_array).any(1)
-                self.labels[i]["cls"] = cls[j]
-                self.labels[i]["bboxes"] = bboxes[j]
+                self.labels[i]['cls'] = cls[j]
+                self.labels[i]['bboxes'] = bboxes[j]
                 if segments:
-                    self.labels[i]["segments"] = segments[j]
+                    self.labels[i]['segments'] = segments[j]
             if self.single_cls:
-                self.labels[i]["cls"][:, 0] = 0
+                self.labels[i]['cls'][:, 0] = 0
 
     def load_image(self, i):
         # Loads 1 image from dataset index 'i', returns (im, resized hw)

@@ -119,7 +119,7 @@ class BaseDataset(Dataset):
         else:  # read image
             im = cv2.imread(f)  # BGR
             if im is None:
-                raise FileNotFoundError(f"Image Not Found {f}")
+                raise FileNotFoundError(f'Image Not Found {f}')
             h0, w0 = im.shape[:2]  # orig hw
             r = self.imgsz / max(h0, w0)  # ratio
             if r != 1:  # if sizes are not equal
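
A concrete feel for the ratio above: a 1920x1080 source with imgsz=640 gives r = 640/1920 = 1/3, so the image is loaded at 640x360 and only letterboxed to square later. In numbers:

h0, w0, imgsz = 1080, 1920, 640
r = imgsz / max(h0, w0)             # 0.333...
new_wh = (int(w0 * r), int(h0 * r))  # (640, 360)
print(r, new_wh)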
@@ -132,17 +132,17 @@ class BaseDataset(Dataset):
         # cache images to memory or disk
         gb = 0  # Gigabytes of cached images
         self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
-        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
+        fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
         with ThreadPool(NUM_THREADS) as pool:
             results = pool.imap(fcn, range(self.ni))
             pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
             for i, x in pbar:
-                if cache == "disk":
+                if cache == 'disk':
                     gb += self.npy_files[i].stat().st_size
                 else:  # 'ram'
                     self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                     gb += self.ims[i].nbytes
-                pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
+                pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
             pbar.close()
 
     def cache_images_to_disk(self, i):

@@ -155,7 +155,7 @@ class BaseDataset(Dataset):
         bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
         nb = bi[-1] + 1  # number of batches
 
-        s = np.array([x.pop("shape") for x in self.labels])  # hw
+        s = np.array([x.pop('shape') for x in self.labels])  # hw
         ar = s[:, 0] / s[:, 1]  # aspect ratio
         irect = ar.argsort()
         self.im_files = [self.im_files[i] for i in irect]

@@ -180,14 +180,14 @@ class BaseDataset(Dataset):
     def get_label_info(self, index):
         label = self.labels[index].copy()
-        label.pop("shape", None)  # shape is for rect, remove it
-        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
-        label["ratio_pad"] = (
-            label["resized_shape"][0] / label["ori_shape"][0],
-            label["resized_shape"][1] / label["ori_shape"][1],
+        label.pop('shape', None)  # shape is for rect, remove it
+        label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
+        label['ratio_pad'] = (
+            label['resized_shape'][0] / label['ori_shape'][0],
+            label['resized_shape'][1] / label['ori_shape'][1],
         )  # for evaluation
         if self.rect:
-            label["rect_shape"] = self.batch_shapes[self.batch[index]]
+            label['rect_shape'] = self.batch_shapes[self.batch[index]]
         label = self.update_labels_info(label)
         return label

@@ -28,7 +28,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
+        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
         self.iterator = super().__iter__()
 
     def __len__(self):

@@ -61,9 +61,9 @@ def seed_worker(worker_id):
     random.seed(worker_seed)
 
 
-def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
-    assert mode in ["train", "val"]
-    shuffle = mode == "train"
+def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
+    assert mode in ['train', 'val']
+    shuffle = mode == 'train'
     if cfg.rect and shuffle:
         LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
         shuffle = False
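
For reference, the worker seeding whose tail appears just above typically derives a per-worker seed from torch's initial seed so numpy and random stay reproducible across DataLoader workers. A sketch of that standard pattern (wired in below via worker_init_fn):

import random
import numpy as np
import torch

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32  # fold the 64-bit torch seed into 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)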
@@ -72,21 +72,21 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
         img_path=img_path,
         imgsz=cfg.imgsz,
         batch_size=batch,
-        augment=mode == "train",  # augmentation
+        augment=mode == 'train',  # augmentation
         hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
         rect=cfg.rect or rect,  # rectangular batches
         cache=cfg.cache or None,
         single_cls=cfg.single_cls or False,
         stride=int(stride),
-        pad=0.0 if mode == "train" else 0.5,
-        prefix=colorstr(f"{mode}: "),
-        use_segments=cfg.task == "segment",
-        use_keypoints=cfg.task == "keypoint",
+        pad=0.0 if mode == 'train' else 0.5,
+        prefix=colorstr(f'{mode}: '),
+        use_segments=cfg.task == 'segment',
+        use_keypoints=cfg.task == 'keypoint',
         names=names)
 
     batch = min(batch, len(dataset))
     nd = torch.cuda.device_count()  # number of CUDA devices
-    workers = cfg.workers if mode == "train" else cfg.workers * 2
+    workers = cfg.workers if mode == 'train' else cfg.workers * 2
     nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader  # allow attribute updates

@@ -98,7 +98,7 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
                   num_workers=nw,
                   sampler=sampler,
                   pin_memory=PIN_MEMORY,
-                  collate_fn=getattr(dataset, "collate_fn", None),
+                  collate_fn=getattr(dataset, 'collate_fn', None),
                   worker_init_fn=seed_worker,
                   generator=generator), dataset

@@ -151,7 +151,7 @@ def check_source(source):
         from_img = True
     else:
         raise Exception(
-            "Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
+            'Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict')
 
     return source, webcam, screenshot, from_img, in_memory

@@ -47,7 +47,7 @@ class LoadStreams:
                 # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                 check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                 import pafy  # noqa
-                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
+                s = pafy.new(s).getbest(preftype='mp4').url  # YouTube URL
             s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
             if s == 0 and (is_colab() or is_kaggle()):
                 raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "

@@ -65,7 +65,7 @@ class LoadStreams:
             if not success or self.imgs[i] is None:
                 raise ConnectionError(f'{st}Failed to read images from {s}')
             self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
-            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
+            LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
             self.threads[i].start()
         LOGGER.info('')  # newline

@@ -145,11 +145,11 @@ class LoadScreenshots:
         # Parse monitor shape
         monitor = self.sct.monitors[self.screen]
-        self.top = monitor["top"] if top is None else (monitor["top"] + top)
-        self.left = monitor["left"] if left is None else (monitor["left"] + left)
-        self.width = width or monitor["width"]
-        self.height = height or monitor["height"]
-        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
+        self.top = monitor['top'] if top is None else (monitor['top'] + top)
+        self.left = monitor['left'] if left is None else (monitor['left'] + left)
+        self.width = width or monitor['width']
+        self.height = height or monitor['height']
+        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
 
     def __iter__(self):
         return self

@@ -157,7 +157,7 @@ class LoadScreenshots:
     def __next__(self):
         # mss screen capture: get raw pixels from the screen as np array
         im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
-        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
+        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
 
         if self.transforms:
             im = self.transforms(im0)  # transforms
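
The monitor dict built above is exactly what mss expects. A minimal standalone capture of the same kind (mss must be installed; the region values are illustrative):

import mss
import numpy as np

with mss.mss() as sct:
    monitor = {'left': 0, 'top': 0, 'width': 640, 'height': 480}
    im0 = np.array(sct.grab(monitor))[:, :, :3]  # drop alpha: BGRA -> BGR
    print(im0.shape)  # (480, 640, 3)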
@@ -172,7 +172,7 @@ class LoadScreenshots:
 class LoadImages:
     # YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
     def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
-        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
+        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
             path = Path(path).read_text().rsplit()
         files = []
         for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:

@@ -290,12 +290,12 @@ class LoadPilAndNumpy:
         self.transforms = transforms
         self.mode = 'image'
         # generate fake paths
-        self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
+        self.paths = [f'image{i}.jpg' for i in range(len(self.im0))]
         self.bs = len(self.im0)
 
     @staticmethod
     def _single_check(im):
-        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
+        assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
         if isinstance(im, Image.Image):
             im = np.asarray(im)[:, :, ::-1]
             im = np.ascontiguousarray(im)  # contiguous

@@ -338,16 +338,16 @@ def autocast_list(source):
         elif isinstance(im, (Image.Image, np.ndarray)):  # PIL or np Image
             files.append(im)
         else:
-            raise TypeError(f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
-                            f"See https://docs.ultralytics.com/predict for supported source types.")
+            raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
+                            f'See https://docs.ultralytics.com/predict for supported source types.')
 
     return files
 
 
 LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
 
-if __name__ == "__main__":
-    img = cv2.imread(str(ROOT / "assets/bus.jpg"))
+if __name__ == '__main__':
+    img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
     dataset = LoadPilAndNumpy(im0=img)
     for d in dataset:
         print(d[0])

@@ -92,7 +92,7 @@ def exif_transpose(image):
         if method is not None:
             image = image.transpose(method)
             del exif[0x0112]
-            image.info["exif"] = exif.tobytes()
+            image.info['exif'] = exif.tobytes()
     return image

@@ -217,11 +217,11 @@ class LoadScreenshots:
         # Parse monitor shape
         monitor = self.sct.monitors[self.screen]
-        self.top = monitor["top"] if top is None else (monitor["top"] + top)
-        self.left = monitor["left"] if left is None else (monitor["left"] + left)
-        self.width = width or monitor["width"]
-        self.height = height or monitor["height"]
-        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
+        self.top = monitor['top'] if top is None else (monitor['top'] + top)
+        self.left = monitor['left'] if left is None else (monitor['left'] + left)
+        self.width = width or monitor['width']
+        self.height = height or monitor['height']
+        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
 
     def __iter__(self):
         return self

@@ -229,7 +229,7 @@ class LoadScreenshots:
     def __next__(self):
         # mss screen capture: get raw pixels from the screen as np array
         im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
-        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
+        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
 
         if self.transforms:
             im = self.transforms(im0)  # transforms

@@ -244,7 +244,7 @@ class LoadScreenshots:
 class LoadImages:
     # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
     def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
-        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
+        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
             path = Path(path).read_text().rsplit()
         files = []
         for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:

@@ -363,7 +363,7 @@ class LoadStreams:
                 # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                 check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                 import pafy
-                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
+                s = pafy.new(s).getbest(preftype='mp4').url  # YouTube URL
             s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
             if s == 0:
                 assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'

@@ -378,7 +378,7 @@ class LoadStreams:
             _, self.imgs[i] = cap.read()  # guarantee first frame
             self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
-            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
+            LOGGER.info(f'{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)')
             self.threads[i].start()
         LOGGER.info('')  # newline

@@ -500,7 +500,7 @@ class LoadImagesAndLabels(Dataset):
         # Display cache
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in {-1, 0}:
-            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
             tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
             if cache['msgs']:
                 LOGGER.info('\n'.join(cache['msgs']))  # display warnings

@@ -604,8 +604,8 @@ class LoadImagesAndLabels(Dataset):
         mem = psutil.virtual_memory()
         cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
         if not cache:
-            LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
-                        f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
+            LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
+                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
                         f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
         return cache

@@ -615,7 +615,7 @@ class LoadImagesAndLabels(Dataset):
             path.unlink()  # remove *.cache file if exists
         x = {}  # dict
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
-        desc = f"{prefix}Scanning {path.parent / path.stem}..."
+        desc = f'{prefix}Scanning {path.parent / path.stem}...'
         total = len(self.im_files)
         with ThreadPool(NUM_THREADS) as pool:
             results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))

@@ -629,7 +629,7 @@ class LoadImagesAndLabels(Dataset):
                     x[im_file] = [lb, shape, segments]
                 if msg:
                     msgs.append(msg)
-                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
             pbar.close()
 
         if msgs:

@@ -1060,7 +1060,7 @@ class HUBDatasetStats():
                 if zipped:
                     data['path'] = data_dir
             except Exception as e:
-                raise Exception("error/HUB/dataset_stats/yaml_load") from e
+                raise Exception('error/HUB/dataset_stats/yaml_load') from e
 
         check_det_dataset(data, autodownload)  # download dataset if missing
         self.hub_dir = Path(data['path'] + '-hub')

@@ -1187,7 +1187,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         else:  # read image
             im = cv2.imread(f)  # BGR
         if self.album_transforms:
-            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
+            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
         else:
             sample = self.torch_transforms(im)
         return sample, j

@ -28,7 +28,7 @@ class YOLODataset(BaseDataset):
cache=False, cache=False,
augment=True, augment=True,
hyp=None, hyp=None,
prefix="", prefix='',
rect=False, rect=False,
batch_size=None, batch_size=None,
stride=32, stride=32,
@ -40,14 +40,14 @@ class YOLODataset(BaseDataset):
self.use_segments = use_segments self.use_segments = use_segments
self.use_keypoints = use_keypoints self.use_keypoints = use_keypoints
self.names = names self.names = names
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls) super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
def cache_labels(self, path=Path("./labels.cache")): def cache_labels(self, path=Path('./labels.cache')):
# Cache dataset labels, check images and read shapes # Cache dataset labels, check images and read shapes
x = {"labels": []} x = {'labels': []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..." desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
total = len(self.im_files) total = len(self.im_files)
with ThreadPool(NUM_THREADS) as pool: with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(func=verify_image_label, results = pool.imap(func=verify_image_label,
@ -60,7 +60,7 @@ class YOLODataset(BaseDataset):
ne += ne_f ne += ne_f
nc += nc_f nc += nc_f
if im_file: if im_file:
x["labels"].append( x['labels'].append(
dict( dict(
im_file=im_file, im_file=im_file,
shape=shape, shape=shape,
@ -69,68 +69,68 @@ class YOLODataset(BaseDataset):
segments=segments, segments=segments,
keypoints=keypoint, keypoints=keypoint,
normalized=True, normalized=True,
bbox_format="xywh")) bbox_format='xywh'))
if msg: if msg:
msgs.append(msg) msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
pbar.close() pbar.close()
if msgs: if msgs:
LOGGER.info("\n".join(msgs)) LOGGER.info('\n'.join(msgs))
if nf == 0: if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}") LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
x["hash"] = get_hash(self.label_files + self.im_files) x['hash'] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files) x['results'] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings x['msgs'] = msgs # warnings
x["version"] = self.cache_version # cache version x['version'] = self.cache_version # cache version
if is_dir_writeable(path.parent): if is_dir_writeable(path.parent):
if path.exists(): if path.exists():
path.unlink() # remove *.cache file if exists path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
LOGGER.info(f"{self.prefix}New cache created: {path}") LOGGER.info(f'{self.prefix}New cache created: {path}')
else: else:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
return x return x
def get_labels(self): def get_labels(self):
self.label_files = img2label_paths(self.im_files) self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
try: try:
            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        if nf == 0:  # number of labels found
            raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels = cache['labels']
        self.im_files = [lb['im_file'] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        len_cls = sum(len(lb['cls']) for lb in labels)
        len_boxes = sum(len(lb['bboxes']) for lb in labels)
        len_segments = sum(len(lb['segments']) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
            for lb in labels:
                lb['segments'] = []
        if len_cls == 0:
            raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
        return labels
    # TODO: use hyp config to set all these augmentations
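
For context, a minimal sketch of the cache round-trip the hunk above relies on: the labels dict is pickled into a .npy file with np.save, and on reload it is trusted only if its version and path hash still match (file names and dict keys here are illustrative assumptions, not the exact Ultralytics layout):

import hashlib

import numpy as np

def tiny_hash(paths):  # stand-in for get_hash(): digest of the joined path list
    return hashlib.sha256(''.join(paths).encode()).hexdigest()

paths = ['images/1.jpg', 'images/2.jpg']
cache = {'version': '1.0', 'hash': tiny_hash(paths), 'labels': [], 'results': (2, 0, 0, 0, 2)}
np.save('labels.cache', cache)  # numpy appends .npy
loaded = np.load('labels.cache.npy', allow_pickle=True).item()
assert loaded['version'] == '1.0' and loaded['hash'] == tiny_hash(paths)  # otherwise re-cache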
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(bbox_format='xywh',
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
@ -161,12 +161,12 @@ class YOLODataset(BaseDataset):
"""custom your label format here""" """custom your label format here"""
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
# we can make it also support classification and semantic segmentation by add or remove some dict keys there. # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
        bboxes = label.pop('bboxes')
        segments = label.pop('segments')
        keypoints = label.pop('keypoints', None)
        bbox_format = label.pop('bbox_format')
        normalized = label.pop('normalized')
        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
@ -176,15 +176,15 @@ class YOLODataset(BaseDataset):
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == 'img':
                value = torch.stack(value, 0)
            if k in ['masks', 'keypoints', 'bboxes', 'cls']:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch['batch_idx'] = list(new_batch['batch_idx'])
        for i in range(len(new_batch['batch_idx'])):
            new_batch['batch_idx'][i] += i  # add target image index for build_targets()
        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
        return new_batch
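
To see why collate_fn offsets batch_idx by the image's position in the batch, here is a toy run (shapes are made up): each image's targets start with batch_idx 0, and the offset makes every concatenated box traceable back to its source image:

import torch

batch_idx = [torch.zeros(2), torch.zeros(1)]  # 2 boxes in image 0, 1 box in image 1
for i in range(len(batch_idx)):
    batch_idx[i] += i  # add target image index
print(torch.cat(batch_idx, 0))  # tensor([0., 0., 1.])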
@ -202,9 +202,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
@ -217,7 +217,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return {'img': sample, 'cls': j}
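
A rough sketch of the two cache modes set up above, under the assumption that each sample row carries (file, class, npy_path, image) as in the __init__ shown earlier: 'ram' keeps the decoded array on the sample, while 'disk' decodes once into an .npy sidecar and reloads it on later epochs:

import cv2
import numpy as np
from pathlib import Path

f, fn, im = 'img.jpg', Path('img.npy'), None  # hypothetical sample entry
cache_ram, cache_disk = False, True
if cache_ram and im is None:
    im = cv2.imread(f)  # keep decoded BGR array in memory
elif cache_disk:
    if not fn.exists():
        np.save(str(fn), cv2.imread(f))  # decode once, reuse across epochs
    im = np.load(fn)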

@ -25,15 +25,15 @@ class MixAndRectDataset:
        labels = deepcopy(self.dataset[index])
        for transform in self.dataset.transforms.tolist():
            # mosaic and mixup
            if hasattr(transform, 'get_indexes'):
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
                labels['mix_labels'] = mix_labels
            if self.dataset.rect and isinstance(transform, LetterBox):
                transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
            labels = transform(labels)
            if 'mix_labels' in labels:
                labels.pop('mix_labels')
        return labels

@ -55,4 +55,4 @@ download: |
    for r in x[images == im]:
        w, h = r[6], r[7]  # image width, height
        xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0]  # instance
        f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n")  # write label

@ -112,4 +112,4 @@ download: |
  urls = ['http://images.cocodataset.org/zips/train2017.zip',  # 19G, 118k images
          'http://images.cocodataset.org/zips/val2017.zip',  # 1G, 5k images
          'http://images.cocodataset.org/zips/test2017.zip']  # 7G, 41k images (optional)
  download(urls, dir=dir / 'images', threads=3)

@ -98,4 +98,4 @@ names:
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco128-seg.zip

@ -98,4 +98,4 @@ names:
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco128.zip

@ -98,4 +98,4 @@ names:
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8-seg.zip

@ -98,4 +98,4 @@ names:
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8.zip

@ -18,32 +18,32 @@ from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download, safe_download
from ultralytics.yolo.utils.ops import segments2boxes

HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders
IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
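
For example, with the usual dataset layout the function maps the /images/ path segment to /labels/ and swaps the suffix (POSIX paths assumed, where os.sep == '/'):

print(img2label_paths(['coco/images/train2017/000000000009.jpg']))
# -> ['coco/labels/train2017/000000000009.txt']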
def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
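
The digest depends on both the total file size and the joined path string, so adding, removing, renaming or resizing any file invalidates a cache keyed on it; a small sketch of the expected behaviour:

h1 = get_hash(['images/1.jpg', 'labels/1.txt'])
h2 = get_hash(['images/1.jpg', 'labels/1.txt'])
assert h1 == h2  # deterministic for the same paths and sizes
assert h1 != get_hash(['images/2.jpg', 'labels/2.txt'])  # different paths -> different digest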
@ -61,21 +61,21 @@ def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix, keypoint, num_cls = args
    # number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
@ -90,31 +90,31 @@ def verify_image_label(args):
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == 56, 'labels require 56 columns each'
                    assert (lb[:, 5::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                    assert (lb[:, 6::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                    kpts = np.zeros((lb.shape[0], 39))
                    for i in range(len(lb)):
                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3))  # remove occlusion param from GT
                        kpts[i] = np.hstack((lb[i, :5], kpt))
                    lb = kpts
                    assert lb.shape[1] == 39, 'labels require 39 columns each after removing occlusion parameter'
                else:
                    assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                    assert (lb[:, 1:] <= 1).all(), \
                        f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                # All labels
                max_cls = int(lb[:, 0].max())  # max label count
                assert max_cls <= num_cls, \
                    f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
                    f'Possible class labels are 0-{num_cls - 1}'
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
@ -127,7 +127,7 @@ def verify_image_label(args):
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
@ -248,8 +248,8 @@ def check_det_dataset(dataset, autodownload=True):
        else:  # python script
            r = exec(s, {'yaml': data})  # return None
        dt = f'({round(time.time() - t, 1)}s)'
        s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt}'
        LOGGER.info(f'Dataset download {s}\n')
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts

    return data  # dictionary
@ -284,9 +284,9 @@ def check_cls_dataset(dataset: str):
        download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / 'train'
    test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))
    return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}

@ -144,7 +144,7 @@ class Exporter:
    @smart_inference_mode()
    def __call__(self, model=None):
        self.run_callbacks('on_export_start')
        t = time.time()
        format = self.args.format.lower()  # to lowercase
        if format in {'tensorrt', 'trt'}:  # engine aliases
@ -207,7 +207,7 @@ class Exporter:
        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
        self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
        self.metadata = {
            'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
            'author': 'Ultralytics',
            'license': 'GPL-3.0 https://ultralytics.com/license',
            'version': __version__,
@ -215,7 +215,7 @@ class Exporter:
            'names': model.names}  # model metadata
        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
                    f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')

        # Exports
        f = [''] * len(fmts)  # exported filenames
@ -259,15 +259,15 @@ class Exporter:
        s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
                              f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
        imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
        data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
        LOGGER.info(
            f'\nExport complete ({time.time() - t:.1f}s)'
            f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
            f'\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
            f'\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
            f'\nVisualize: https://netron.app')
        self.run_callbacks('on_export_end')
        return f  # return list of exported files/dirs

    @try_export
@ -277,7 +277,7 @@ class Exporter:
        f = self.file.with_suffix('.torchscript')

        ts = torch.jit.trace(self.model, self.im, strict=False)
        d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
        if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            LOGGER.info(f'{prefix} optimizing for mobile...')
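
The metadata dict embedded above travels inside the TorchScript archive and can be read back at load time through the same _extra_files mechanism; a minimal sketch (the .torchscript filename is an assumption):

import json

import torch

extra_files = {'config.txt': ''}  # keys to extract; values are filled in-place on load
model = torch.jit.load('yolov8n.torchscript', _extra_files=extra_files)
d = json.loads(extra_files['config.txt'])  # {'shape': ..., 'stride': ..., 'names': ...}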
@ -354,7 +354,7 @@ class Exporter:
        ov_model = mo.convert_model(f_onnx,
                                    model_name=self.pretty_name,
                                    framework='onnx',
                                    compress_to_fp16=self.args.half)  # export
        ov.serialize(ov_model, f_ov)  # save
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
@ -471,7 +471,7 @@ class Exporter:
        if self.args.dynamic:
            shape = self.im.shape
            if shape[0] <= 1:
                LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
            profile = builder.create_optimization_profile()
            for inp in inputs:
                profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
@ -509,8 +509,8 @@ class Exporter:
        except ImportError:
            check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
            import tensorflow as tf  # noqa
        check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
                           cmds='--extra-index-url https://pypi.ngc.nvidia.com')

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = str(self.file).replace(self.file.suffix, '_saved_model')
@ -632,7 +632,7 @@ class Exporter:
            converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

        tflite_model = converter.convert()
        open(f, 'wb').write(tflite_model)
        return f, None
    @try_export
@ -656,7 +656,7 @@ class Exporter:
        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
        f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model

        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
        subprocess.run(cmd.split(), check=True)
        self._add_tflite_metadata(f)
        return f, None
@ -707,8 +707,8 @@ class Exporter:
        # Creates input info.
        input_meta = _metadata_fb.TensorMetadataT()
        input_meta.name = 'image'
        input_meta.description = 'Input image to be detected.'
        input_meta.content = _metadata_fb.ContentT()
        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
        input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
@ -716,8 +716,8 @@ class Exporter:
        # Creates output info.
        output_meta = _metadata_fb.TensorMetadataT()
        output_meta.name = 'output'
        output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'

        # Label file
        tmp_file = Path('/tmp/meta.txt')
@ -868,8 +868,8 @@ class Exporter:
def export(cfg=DEFAULT_CFG):
    cfg.model = cfg.model or 'yolov8n.yaml'
    cfg.format = cfg.format or 'torchscript'

    # exporter = Exporter(cfg)
    #
@ -888,7 +888,7 @@ def export(cfg=DEFAULT_CFG):
    model.export(**vars(cfg))


if __name__ == '__main__':
    """
    CLI:
    yolo mode=export model=yolov8n.yaml format=onnx

@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
    'classify': [
        ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
        'yolo.TYPE.classify.ClassificationPredictor'],
    'detect': [
        DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
        'yolo.TYPE.detect.DetectionPredictor'],
    'segment': [
        SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
        'yolo.TYPE.segment.SegmentationPredictor']}
@ -34,7 +34,7 @@ class YOLO:
    A Python interface which emulates model-like behaviour by wrapping trainers.
    """

    def __init__(self, model='yolov8n.pt', type='v8') -> None:
        """
        Initializes the YOLO object.
@ -94,7 +94,7 @@ class YOLO:
        suffix = Path(weights).suffix
        if suffix == '.pt':
            self.model, self.ckpt = attempt_load_one_weight(weights)
            self.task = self.model.args['task']
            self.overrides = self.model.args
            self._reset_ckpt_args(self.overrides)
        else:
@ -111,7 +111,7 @@ class YOLO:
""" """
if not isinstance(self.model, nn.Module): if not isinstance(self.model, nn.Module):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. " raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f"PyTorch models can be used to train, val, predict and export, i.e. " f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only " f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.") f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
@ -155,11 +155,11 @@ class YOLO:
            (List[ultralytics.yolo.engine.results.Results]): The prediction results.
        """
        overrides = self.overrides.copy()
        overrides['conf'] = 0.25
        overrides.update(kwargs)
        overrides['mode'] = kwargs.get('mode', 'predict')
        assert overrides['mode'] in ['track', 'predict']
        overrides['save'] = kwargs.get('save', False)  # not save files by default
        if not self.predictor:
            self.predictor = self.PredictorClass(overrides=overrides)
            self.predictor.setup_model(model=self.model)
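
In practice these overrides come straight from the caller's kwargs, so a one-off prediction with a custom confidence threshold looks roughly like this (source path and threshold are illustrative):

from ultralytics import YOLO

model = YOLO('yolov8n.pt')
results = model.predict(source='bus.jpg', conf=0.5, save=True)  # conf overrides the 0.25 default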
@ -173,7 +173,7 @@ class YOLO:
        from ultralytics.tracker.track import register_tracker
        register_tracker(self)
        # bytetrack-based method needs low confidence predictions as input
        conf = kwargs.get('conf') or 0.1
        kwargs['conf'] = conf
        kwargs['mode'] = 'track'
        return self.predict(source=source, stream=stream, **kwargs)
@ -188,9 +188,9 @@ class YOLO:
            **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
        """
        overrides = self.overrides.copy()
        overrides['rect'] = True  # rect batches as default
        overrides.update(kwargs)
        overrides['mode'] = 'val'
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.data = data or args.data
        args.task = self.task
@ -234,18 +234,18 @@ class YOLO:
        self._check_is_pytorch_model()
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get('cfg'):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
        overrides['task'] = self.task
        overrides['mode'] = 'train'
        if not overrides.get('data'):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get('resume'):
            overrides['resume'] = self.ckpt_path

        self.trainer = self.TrainerClass(overrides=overrides)
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.train()
@ -267,9 +267,9 @@ class YOLO:
    def _assign_ops_from_task(self):
        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
        trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
        validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
        predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
        return model_class, trainer_class, validator_class, predictor_class
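
The TYPE placeholder in MODEL_MAP is substituted with the model type before eval() resolves the literal against the imported yolo package; for the default type='v8' the rewrite is simply:

train_lit = 'yolo.TYPE.detect.DetectionTrainer'
print(train_lit.replace('TYPE', 'v8'))  # 'yolo.v8.detect.DetectionTrainer'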
    @property
@ -292,7 +292,7 @@ class YOLO:
        Returns metrics if computed
        """
        if not self.metrics_data:
            LOGGER.info('No metrics data found! Run training or validation operation first.')
        return self.metrics_data

@ -72,7 +72,7 @@ class BasePredictor:
""" """
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}" name = self.args.name or f'{self.args.mode}'
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
if self.args.conf is None: if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25 self.args.conf = 0.25 # default conf=0.25
@ -97,10 +97,10 @@ class BasePredictor:
        pass

    def get_annotator(self, img):
        raise NotImplementedError('get_annotator function needs to be implemented')

    def write_results(self, results, batch, print_string):
        raise NotImplementedError('write_results function needs to be implemented')

    def postprocess(self, preds, img, orig_img):
        return preds
@ -135,7 +135,7 @@ class BasePredictor:
    def stream_inference(self, source=None, model=None):
        if self.args.verbose:
            LOGGER.info('')

        # setup model
        if not self.model:
@ -152,9 +152,9 @@ class BasePredictor:
            self.done_warmup = True

        self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
        self.run_callbacks('on_predict_start')
        for batch in self.dataset:
            self.run_callbacks('on_predict_batch_start')
            self.batch = batch
            path, im, im0s, vid_cap, s = batch
            visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
@ -170,7 +170,7 @@ class BasePredictor:
            # postprocess
            with self.dt[2]:
                self.results = self.postprocess(preds, im, im0s)
            self.run_callbacks('on_predict_postprocess_end')

            # visualize, save, write results
            for i in range(len(im)):
@ -186,7 +186,7 @@ class BasePredictor:
                if self.args.save:
                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))
            self.run_callbacks('on_predict_batch_end')
            yield from self.results

        # Print time (inference-only)
@ -207,7 +207,7 @@ class BasePredictor:
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end") self.run_callbacks('on_predict_end')
def setup_model(self, model): def setup_model(self, model):
device = select_device(self.args.device) device = select_device(self.args.device)

@ -36,7 +36,7 @@ class Results:
        self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
        self.probs = probs if probs is not None else None
        self.names = names
        self.comp = ['boxes', 'masks', 'probs']

    def pandas(self):
        pass
@ -97,7 +97,7 @@ class Results:
        return len(getattr(self, item))

    def __str__(self):
        str_out = ''
        for item in self.comp:
            if getattr(self, item) is None:
                continue
@ -105,7 +105,7 @@ class Results:
        return str_out

    def __repr__(self):
        str_out = ''
        for item in self.comp:
            if getattr(self, item) is None:
                continue
@ -187,7 +187,7 @@ class Boxes:
        if boxes.ndim == 1:
            boxes = boxes[None, :]
        n = boxes.shape[-1]
        assert n in {6, 7}, f'expected `n` in [6, 7], but got {n}'  # xyxy, (track_id), conf, cls
        # TODO
        self.is_track = n == 7
        self.boxes = boxes
@ -268,8 +268,8 @@ class Boxes:
        return self.boxes.__str__()

    def __repr__(self):
        return (f'Ultralytics YOLO {self.__class__} boxes\n' + f'type: {type(self.boxes)}\n' +
                f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n' + self.boxes.__repr__())

    def __getitem__(self, idx):
        boxes = self.boxes[idx]
@ -353,8 +353,8 @@ class Masks:
        return self.masks.__str__()

    def __repr__(self):
        return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
                f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n' + self.masks.__repr__())

    def __getitem__(self, idx):
        masks = self.masks[idx]
@ -374,19 +374,19 @@ class Masks:
""") """)
if __name__ == "__main__": if __name__ == '__main__':
# test examples # test examples
results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640]) results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
results = results.cuda() results = results.cuda()
print("--cuda--pass--") print('--cuda--pass--')
results = results.cpu() results = results.cpu()
print("--cpu--pass--") print('--cpu--pass--')
results = results.to("cuda:0") results = results.to('cuda:0')
print("--to-cuda--pass--") print('--to-cuda--pass--')
results = results.to("cpu") results = results.to('cpu')
print("--to-cpu--pass--") print('--to-cpu--pass--')
results = results.numpy() results = results.numpy()
print("--numpy--pass--") print('--numpy--pass--')
# box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5]) # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
# box = box.cuda() # box = box.cuda()
# box = box.cpu() # box = box.cpu()

@ -90,7 +90,7 @@ class BaseTrainer:
        # Dirs
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        if hasattr(self.args, 'save_dir'):
            self.save_dir = Path(self.args.save_dir)
        else:
@ -121,7 +121,7 @@ class BaseTrainer:
        try:
            if self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data)
            elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
                self.data = check_det_dataset(self.args.data)
                if 'yaml_file' in self.data:
                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
@ -175,7 +175,7 @@ class BaseTrainer:
            world_size = 0

        # Run subprocess if DDP training, else train normally
        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
            cmd, file = generate_ddp_command(world_size, self)  # security vulnerability in Snyk scans
            try:
                subprocess.run(cmd, check=True)
@ -191,15 +191,15 @@ class BaseTrainer:
        # os.environ['MASTER_PORT'] = '9020'
        torch.cuda.set_device(rank)
        self.device = torch.device('cuda', rank)
        self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
        dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)

    def _setup_train(self, rank, world_size):
        """
        Builds dataloaders and optimizer on correct rank process.
        """
        # model
        self.run_callbacks('on_pretrain_routine_start')
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()
@ -234,16 +234,16 @@ class BaseTrainer:
        # dataloaders
        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
        if rank in {0, -1}:
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
            self.ema = ModelEMA(self.model)
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')

    def _do_train(self, rank=-1, world_size=1):
        if world_size > 1:
@ -257,24 +257,24 @@ class BaseTrainer:
        nb = len(self.train_loader)  # number of batches
        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
        last_opt_step = -1
        self.run_callbacks('on_train_start')
        self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                 f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                 f"Logging results to {colorstr('bold', self.save_dir)}\n"
                 f'Starting training for {self.epochs} epochs...')
        if self.args.close_mosaic:
            base_idx = (self.epochs - self.args.close_mosaic) * nb
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])

        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch
            self.run_callbacks('on_train_epoch_start')
            self.model.train()
            if rank != -1:
                self.train_loader.sampler.set_epoch(epoch)
            pbar = enumerate(self.train_loader)
            # Update dataloader attributes (optional)
            if epoch == (self.epochs - self.args.close_mosaic):
                self.console.info('Closing dataloader mosaic')
                if hasattr(self.train_loader.dataset, 'mosaic'):
                    self.train_loader.dataset.mosaic = False
                if hasattr(self.train_loader.dataset, 'close_mosaic'):
@ -286,7 +286,7 @@ class BaseTrainer:
            self.tloss = None
            self.optimizer.zero_grad()
            for i, batch in pbar:
                self.run_callbacks('on_train_batch_start')

                # Warmup
                ni = i + nb * epoch
                if ni <= nw:
@ -302,7 +302,7 @@ class BaseTrainer:
                # Forward
                with torch.cuda.amp.autocast(self.amp):
                    batch = self.preprocess_batch(batch)
                    preds = self.model(batch['img'])
                    self.loss, self.loss_items = self.criterion(preds, batch)
                    if rank != -1:
                        self.loss *= world_size
@ -324,17 +324,17 @@ class BaseTrainer:
                if rank in {-1, 0}:
                    pbar.set_description(
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                    self.run_callbacks('on_batch_end')
                    if self.args.plots and ni in self.plot_idx:
                        self.plot_training_samples(batch, ni)

                self.run_callbacks('on_train_batch_end')

            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
            self.scheduler.step()
            self.run_callbacks('on_train_epoch_end')

            if rank in {-1, 0}:
@ -355,7 +355,7 @@ class BaseTrainer:
            tnow = time.time()
            self.epoch_time = tnow - self.epoch_time_start
            self.epoch_time_start = tnow
            self.run_callbacks('on_fit_epoch_end')

            # Early Stopping
            if RANK != -1:  # if DDP training
@ -402,7 +402,7 @@ class BaseTrainer:
""" """
Get train, val path from data dict if it exists. Returns None if data format is not recognized. Get train, val path from data dict if it exists. Returns None if data format is not recognized.
""" """
return data["train"], data.get("val") or data.get("test") return data['train'], data.get('val') or data.get('test')
def setup_model(self): def setup_model(self):
""" """
@ -413,9 +413,9 @@ class BaseTrainer:
        model, weights = self.model, None
        ckpt = None
        if str(model).endswith('.pt'):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = ckpt['model'].yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@ -441,7 +441,7 @@ class BaseTrainer:
        Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
        """
        metrics = self.validator(self)
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness
@ -462,38 +462,38 @@ class BaseTrainer:
raise NotImplementedError("This task trainer doesn't support loading cfg files") raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self): def get_validator(self):
raise NotImplementedError("get_validator function not implemented in trainer") raise NotImplementedError('get_validator function not implemented in trainer')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
""" """
Returns dataloader derived from torch.data.Dataloader. Returns dataloader derived from torch.data.Dataloader.
""" """
raise NotImplementedError("get_dataloader function not implemented in trainer") raise NotImplementedError('get_dataloader function not implemented in trainer')
def criterion(self, preds, batch): def criterion(self, preds, batch):
""" """
Returns loss and individual loss items as Tensor. Returns loss and individual loss items as Tensor.
""" """
raise NotImplementedError("criterion function not implemented in trainer") raise NotImplementedError('criterion function not implemented in trainer')
def label_loss_items(self, loss_items=None, prefix="train"): def label_loss_items(self, loss_items=None, prefix='train'):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor
""" """
# Not needed for classification but necessary for segmentation & detection # Not needed for classification but necessary for segmentation & detection
return {"loss": loss_items} if loss_items is not None else ["loss"] return {'loss': loss_items} if loss_items is not None else ['loss']
def set_model_attributes(self): def set_model_attributes(self):
""" """
To set or update model parameters before training. To set or update model parameters before training.
""" """
self.model.names = self.data["names"] self.model.names = self.data['names']
def build_targets(self, preds, targets): def build_targets(self, preds, targets):
pass pass
def progress_string(self): def progress_string(self):
return "" return ''
# TODO: may need to put these following functions into callback # TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
@ -529,7 +529,7 @@ class BaseTrainer:
                self.args = get_cfg(attempt_load_weights(last).args)
                self.args.model, resume = str(last), True  # reinstate
            except Exception as e:
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
        self.resume = resume
@ -557,7 +557,7 @@ class BaseTrainer:
            self.best_fitness = best_fitness
        self.start_epoch = start_epoch
        if start_epoch > (self.epochs - self.args.close_mosaic):
            self.console.info('Closing dataloader mosaic')
            if hasattr(self.train_loader.dataset, 'mosaic'):
                self.train_loader.dataset.mosaic = False
            if hasattr(self.train_loader.dataset, 'close_mosaic'):
@ -602,5 +602,5 @@ class BaseTrainer:
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights) optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups " LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias") f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
return optimizer return optimizer
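These are the hooks a task-specific trainer overrides: `get_dataloader` and `criterion` raise `NotImplementedError` in the base class, while `label_loss_items` and `set_model_attributes` supply defaults. A minimal sketch of the pattern (the `ToyTrainer` below is hypothetical and standalone, not one of the shipped trainers):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyTrainer:  # hypothetical subclass illustrating the BaseTrainer hooks
    loss_names = ['loss']

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        # Any torch DataLoader satisfies the contract
        x, y = torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=mode == 'train')

    def criterion(self, preds, batch):
        loss = torch.nn.functional.cross_entropy(preds, batch[1])
        return loss, loss.detach()  # (total loss, loss items for logging)

    def label_loss_items(self, loss_items=None, prefix='train'):
        keys = [f'{prefix}/{x}' for x in self.loss_names]
        return dict(zip(keys, [loss_items])) if loss_items is not None else keys
```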

@@ -62,7 +62,7 @@ class BaseValidator:
        self.jdict = None

        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        self.save_dir = save_dir or increment_path(Path(project) / name,
                                                   exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

@@ -92,7 +92,7 @@ class BaseValidator:
        else:
            callbacks.add_integration_callbacks(self)
            self.run_callbacks('on_val_start')
            assert model is not None, 'Either trainer or model is needed for validation'
            self.device = select_device(self.args.device, self.args.batch)
            self.args.half &= self.device.type != 'cpu'
            model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)

@@ -108,7 +108,7 @@ class BaseValidator:
                self.logger.info(
                    f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

            if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data)

@@ -142,7 +142,7 @@ class BaseValidator:
            # inference
            with dt[1]:
                preds = model(batch['img'])

            # loss
            with dt[2]:

@@ -166,14 +166,14 @@ class BaseValidator:
        self.run_callbacks('on_val_end')

        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
                             self.speed)
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / 'predictions.json'), 'w') as f:
                    self.logger.info(f'Saving {f.name}...')
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            return stats

@@ -183,7 +183,7 @@ class BaseValidator:
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        raise NotImplementedError('get_dataloader function not implemented for this validator')

    def preprocess(self, batch):
        return batch
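When called from a trainer, the validator merges its metric stats with the loss items labelled by `label_loss_items(..., prefix='val')`. A toy illustration with plain dicts and made-up numbers:

```python
stats = {'metrics/mAP50(B)': 0.52}                               # hypothetical metric stats
val_losses = {'val/box_loss': 1.23456, 'val/cls_loss': 0.56789}  # labelled loss items
results = {**stats, **val_losses}
print({k: round(float(v), 5) for k, v in results.items()})
# {'metrics/mAP50(B)': 0.52, 'val/box_loss': 1.23456, 'val/cls_loss': 0.56789}
```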

@@ -27,7 +27,7 @@ from ultralytics import __version__
# Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]  # YOLO
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode

@@ -111,7 +111,7 @@ class IterableSimpleNamespace(SimpleNamespace):
        return iter(vars(self).items())

    def __str__(self):
        return '\n'.join(f'{k}={v}' for k, v in vars(self).items())

    def __getattr__(self, attr):
        name = self.__class__.__name__

@@ -288,7 +288,7 @@ def is_pytest_running():
        (bool): True if pytest is running, False otherwise.
    """
    with contextlib.suppress(Exception):
        return 'pytest' in sys.modules
    return False

@@ -336,7 +336,7 @@ def get_git_origin_url():
    """
    if is_git_dir():
        with contextlib.suppress(subprocess.CalledProcessError):
            origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
            return origin.decode().strip()
    return None  # if not git dir or on error

@@ -350,7 +350,7 @@ def get_git_branch():
    """
    if is_git_dir():
        with contextlib.suppress(subprocess.CalledProcessError):
            origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
            return origin.decode().strip()
    return None  # if not git dir or on error

@@ -365,9 +365,9 @@ def get_latest_pypi_version(package_name='ultralytics'):
    Returns:
        str: The latest version of the package.
    """
    response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
    if response.status_code == 200:
        return response.json()['info']['version']
    return None

@@ -424,28 +424,28 @@ def emojis(string=''):
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


def remove_ansi_codes(string):

@@ -466,21 +466,21 @@ def set_logging(name=LOGGING_NAME, verbose=True):
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            name: {
                'format': '%(message)s'}},
        'handlers': {
            name: {
                'class': 'logging.StreamHandler',
                'formatter': name,
                'level': level}},
        'loggers': {
            name: {
                'level': level,
                'handlers': [name],
                'propagate': False}}})


class TryExcept(contextlib.ContextDecorator):

@@ -521,10 +521,10 @@ def set_sentry():
            return None  # do not send event

        event['tags'] = {
            'sys_argv': sys.argv[0],
            'sys_argv_name': Path(sys.argv[0]).name,
            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
            'os': ENVIRONMENT}
        return event

    if SETTINGS['sync'] and \

@@ -533,24 +533,24 @@ def set_sentry():
            not is_pytest_running() and \
            not is_github_actions_ci() and \
            ((is_pip_package() and not is_git_dir()) or
             (get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')):

        import hashlib

        import sentry_sdk  # noqa

        sentry_sdk.init(
            dsn='https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016',
            debug=False,
            traces_sample_rate=1.0,
            release=__version__,
            environment='production',  # 'dev' or 'production'
            before_send=before_send,
            ignore_errors=[KeyboardInterrupt, FileNotFoundError])
        sentry_sdk.set_user({'id': SETTINGS['uuid']})

        # Disable all sentry logging
        for logger in 'sentry_sdk', 'sentry_sdk.errors':
            logging.getLogger(logger).setLevel(logging.CRITICAL)

@@ -620,7 +620,7 @@ if WINDOWS:
        setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x)))  # emoji safe logging

# Check first-install steps
PREFIX = colorstr('Ultralytics: ')
SETTINGS = get_settings()
DATASETS_DIR = Path(SETTINGS['datasets_dir'])  # global datasets directory
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
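`colorstr` wraps a string in ANSI codes: leading arguments pick colors and styles, the last argument is the text, and a lone argument defaults to blue bold. Usage, assuming the import path as of this commit:

```python
from ultralytics.yolo.utils import colorstr  # module path as of this commit

print(colorstr('blue', 'bold', 'hello world'))  # explicit color + style
print(colorstr('hello world'))                  # defaults to ('blue', 'bold')
print(colorstr('red', 'error:') + ' details')   # single color, plain tail
```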

@@ -11,7 +11,7 @@ except (ImportError, AssertionError):
    clearml = None


def _log_images(imgs_dict, group='', step=0):
    task = Task.current_task()
    if task:
        for k, v in imgs_dict.items():

@@ -20,7 +20,7 @@ def _log_images(imgs_dict, group="", step=0):
def on_pretrain_routine_start(trainer):
    # TODO: reuse existing task
    task = Task.init(project_name=trainer.args.project or 'YOLOv8',
                     task_name=trainer.args.name,
                     tags=['YOLOv8'],
                     output_uri=True,

@@ -31,15 +31,15 @@ def on_pretrain_routine_start(trainer):
def on_train_epoch_end(trainer):
    if trainer.epoch == 1:
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)


def on_fit_epoch_end(trainer):
    if trainer.epoch == 0:
        model_info = {
            'Parameters': get_num_params(trainer.model),
            'GFLOPs': round(get_flops(trainer.model), 3),
            'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
        Task.current_task().connect(model_info, name='Model')

@@ -50,7 +50,7 @@ def on_train_end(trainer):
callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_train_epoch_end': on_train_epoch_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if clearml else {}

@@ -10,13 +10,13 @@ except ImportError:
def on_pretrain_routine_start(trainer):
    experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
    experiment.log_parameters(vars(trainer.args))


def on_train_epoch_end(trainer):
    experiment = comet_ml.get_global_experiment()
    experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
    if trainer.epoch == 1:
        for f in trainer.save_dir.glob('train_batch*.jpg'):
            experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)

@@ -27,19 +27,19 @@ def on_fit_epoch_end(trainer):
    experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        model_info = {
            'model/parameters': get_num_params(trainer.model),
            'model/GFLOPs': round(get_flops(trainer.model), 3),
            'model/speed(ms)': round(trainer.validator.speed[1], 3)}
        experiment.log_metrics(model_info, step=trainer.epoch + 1)


def on_train_end(trainer):
    experiment = comet_ml.get_global_experiment()
    experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True)


callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_train_epoch_end': on_train_epoch_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if comet_ml else {}

@@ -11,7 +11,7 @@ def on_pretrain_routine_end(trainer):
    session = getattr(trainer, 'hub_session', None)
    if session:
        # Start timer for upload rate limit
        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
        session.t = {'metrics': time(), 'ckpt': time()}  # start timer on self.rate_limit

@@ -31,7 +31,7 @@ def on_model_save(trainer):
        # Upload checkpoints with rate limiting
        is_best = trainer.best_fitness == trainer.fitness
        if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
            LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
            session.upload_model(trainer.epoch, trainer.last, is_best)
            session.t['ckpt'] = time()  # reset timer

@@ -40,11 +40,11 @@ def on_train_end(trainer):
    session = getattr(trainer, 'hub_session', None)
    if session:
        # Upload final model and metrics with exponential backoff
        LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
                    f'{PREFIX}Uploading final {session.model_id}')
        session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
        session.shutdown()  # stop heartbeats
        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')


def on_train_start(trainer):

@@ -64,11 +64,11 @@ def on_export_start(exporter):
callbacks = {
    'on_pretrain_routine_end': on_pretrain_routine_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_model_save': on_model_save,
    'on_train_end': on_train_end,
    'on_train_start': on_train_start,
    'on_val_start': on_val_start,
    'on_predict_start': on_predict_start,
    'on_export_start': on_export_start}

@@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
def on_batch_end(trainer):
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)


def on_fit_epoch_end(trainer):

@@ -24,6 +24,6 @@ def on_fit_epoch_end(trainer):
callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_batch_end': on_batch_end}

@@ -71,7 +71,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
        msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
              "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        if max_dim != 1:
            raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]

    # Make image size a multiple of the stride

@@ -87,9 +87,9 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    return sz


def check_version(current: str = '0.0.0',
                  minimum: str = '0.0.0',
                  name: str = 'version ',
                  pinned: bool = False,
                  hard: bool = False,
                  verbose: bool = False) -> bool:

@@ -109,7 +109,7 @@ def check_version(current: str = "0.0.0",
    """
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    warning_message = f'WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed'
    if hard:
        assert result, emojis(warning_message)  # assert min requirements met
    if verbose and not result:

@@ -155,7 +155,7 @@ def check_online() -> bool:
    """
    import socket
    with contextlib.suppress(Exception):
        host = socket.gethostbyname('www.github.com')
        socket.create_connection((host, 80), timeout=2)
        return True
    return False

@@ -182,7 +182,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
    file = None
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f'{prefix} {file} not found, check failed.'
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    elif isinstance(requirements, str):

@@ -200,7 +200,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
    if s and install and AUTOINSTALL:  # check environment variable
        LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
        try:
            assert check_online(), 'AutoUpdate skipped (offline)'
            LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
            s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
                f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"

@@ -217,19 +217,19 @@ def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''):
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'


def check_yolov5u_filename(file: str):
    # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
    if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
        original_file = file
        file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file)  # i.e. yolov5n.pt -> yolov5nu.pt
        file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
        if file != original_file:
            LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                        f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
                        f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
    return file

@@ -290,7 +290,7 @@ def check_yolo(verbose=True):
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, used, free = shutil.disk_usage('/')
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display
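`check_version` reduces to a parsed-version comparison: exact match when `pinned`, otherwise a minimum bound. A standalone sketch using `pkg_resources` (imported as `pkg` in this module):

```python
import pkg_resources as pkg

def version_ok(current='0.0.0', minimum='0.0.0', pinned=False):
    # Same comparison as check_version, without the logging/assert wrapping
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    return (current == minimum) if pinned else (current >= minimum)

print(version_ok('8.0.25', minimum='8.0.0'))                # True
print(version_ok('8.0.25', minimum='8.0.25', pinned=True))  # True
print(version_ok('7.0.0', minimum='8.0.0'))                 # False
```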

@@ -22,7 +22,7 @@ def find_free_network_port() -> int:
def generate_ddp_file(trainer):
    import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])

    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir

@@ -32,9 +32,9 @@ def generate_ddp_file(trainer):
    trainer = {trainer.__class__.__name__}(cfg=cfg)
    trainer.train()'''
    (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
    with tempfile.NamedTemporaryFile(prefix='_temp_',
                                     suffix=f'{id(trainer)}.py',
                                     mode='w+',
                                     encoding='utf-8',
                                     dir=USER_CONFIG_DIR / 'DDP',
                                     delete=False) as file:

@@ -47,18 +47,18 @@ def generate_ddp_command(world_size, trainer):
    # Get file and args (do not use sys.argv due to security vulnerability)
    exclude_args = ['save_dir']
    args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
    file = generate_ddp_file(trainer)  # if argv[0].endswith('yolo') else os.path.abspath(argv[0])

    # Build command
    torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
    cmd = [
        sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
        f'{find_free_network_port()}', file] + args
    return cmd, file


def ddp_cleanup(trainer, file):
    # delete temp file if created
    if f'{id(trainer)}.py' in file:  # if temp_file suffix in file
        os.remove(file)
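`find_free_network_port` (named in the hunk header) supplies the `--master_port` value; helpers like this are typically implemented by binding to port 0 so the OS assigns an unused port. A sketch under that assumption, not necessarily the exact implementation here:

```python
import socket

def find_free_network_port() -> int:
    # Bind to port 0 and let the OS pick a free port, then report it
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        return s.getsockname()[1]

print(find_free_network_port())  # e.g. 53421
```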

@@ -95,14 +95,14 @@ def safe_download(url,
                torch.hub.download_url_to_file(url, f, progress=progress)
            else:
                from ultralytics.yolo.utils import TQDM_BAR_FORMAT
                with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
                                                            desc=desc,
                                                            disable=not progress,
                                                            unit='B',
                                                            unit_scale=True,
                                                            unit_divisor=1024,
                                                            bar_format=TQDM_BAR_FORMAT) as pbar:
                    with open(f, 'wb') as f_opened:
                        for data in response:
                            f_opened.write(data)
                            pbar.update(len(data))

@@ -171,7 +171,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
            tag, assets = github_assets(repo)  # latest release
        except Exception:
            try:
                tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
            except Exception:
                tag = release
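The `urlopen` + `tqdm` branch above streams the response body to disk while updating the bar from each chunk's length. A standalone version with tqdm's default bar format (since `TQDM_BAR_FORMAT` is internal to the package):

```python
from urllib import request

from tqdm import tqdm

url, out = 'https://ultralytics.com/images/bus.jpg', 'bus.jpg'
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
                                            unit='B', unit_scale=True, unit_divisor=1024) as pbar:
    with open(out, 'wb') as f:
        for data in response:  # iterate response chunks, as in safe_download
            f.write(data)
            pbar.update(len(data))
```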

@@ -24,15 +24,15 @@ to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(yolo format)
# `ltwh` means left top and width, height(coco format)
_formats = ['xyxy', 'xywh', 'ltwh']

__all__ = ['Bboxes']


class Bboxes:
    """Now only numpy is supported"""

    def __init__(self, bboxes, format='xyxy') -> None:
        assert format in _formats
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
        assert bboxes.ndim == 2

@@ -67,17 +67,17 @@ class Bboxes:
        assert format in _formats
        if self.format == format:
            return
        elif self.format == 'xyxy':
            bboxes = xyxy2xywh(self.bboxes) if format == 'xywh' else xyxy2ltwh(self.bboxes)
        elif self.format == 'xywh':
            bboxes = xywh2xyxy(self.bboxes) if format == 'xyxy' else xywh2ltwh(self.bboxes)
        else:
            bboxes = ltwh2xyxy(self.bboxes) if format == 'xyxy' else ltwh2xywh(self.bboxes)
        self.bboxes = bboxes
        self.format = format

    def areas(self):
        self.convert('xyxy')
        return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])

    # def denormalize(self, w, h):

@@ -128,7 +128,7 @@ class Bboxes:
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
        """
        Concatenates a list of Bboxes into a single Bboxes

@@ -147,7 +147,7 @@ class Bboxes:
            return boxes_list[0]
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))

    def __getitem__(self, index) -> 'Bboxes':
        """
        Args:
            index: int, slice, or a BoolArray

@@ -158,13 +158,13 @@ class Bboxes:
        if isinstance(index, int):
            return Bboxes(self.bboxes[index].view(1, -1))
        b = self.bboxes[index]
        assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
        return Bboxes(b)


class Instances:

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
        """
        Args:
            bboxes (ndarray): bboxes with shape [N, 4].

@@ -227,7 +227,7 @@ class Instances:
    def add_padding(self, padw, padh):
        # handle rect and mosaic situation
        assert not self.normalized, 'you should add padding with absolute coordinates.'
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
        self.segments[..., 1] += padh

@@ -235,7 +235,7 @@ class Instances:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> 'Instances':
        """
        Args:
            index: int, slice, or a BoolArray

@@ -256,7 +256,7 @@ class Instances:
        )

    def flipud(self, h):
        if self._bboxes.format == 'xyxy':
            y1 = self.bboxes[:, 1].copy()
            y2 = self.bboxes[:, 3].copy()
            self.bboxes[:, 1] = h - y2

@@ -268,7 +268,7 @@ class Instances:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        if self._bboxes.format == 'xyxy':
            x1 = self.bboxes[:, 0].copy()
            x2 = self.bboxes[:, 2].copy()
            self.bboxes[:, 0] = w - x2

@@ -281,10 +281,10 @@ class Instances:
    def clip(self, w, h):
        ori_format = self._bboxes.format
        self.convert_bbox(format='xyxy')
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if ori_format != 'xyxy':
            self.convert_bbox(format=ori_format)
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)

@@ -304,7 +304,7 @@ class Instances:
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
        """
        Concatenates a list of Instances into a single Instances object
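The conversions dispatched by `convert` are simple coordinate algebra; for example, `xyxy -> xywh` maps corners to center plus size, and `areas` is width times height on `xyxy` boxes. A hand-rolled check:

```python
import numpy as np

boxes_xyxy = np.array([[10., 20., 50., 80.]])  # one box: x1, y1, x2, y2
x1, y1, x2, y2 = boxes_xyxy.T
boxes_xywh = np.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], axis=1)
print(boxes_xywh)             # [[30. 50. 40. 60.]] -> cx, cy, w, h
print((x2 - x1) * (y2 - y1))  # [2400.] == what areas() returns for this box
```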

@@ -16,7 +16,7 @@ class VarifocalLoss(nn.Module):
    def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
        with torch.cuda.amp.autocast(enabled=False):
            loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
                    weight).sum()
        return loss

@@ -52,5 +52,5 @@ class BboxLoss(nn.Module):
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
                F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
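The `tl`/`tr` weighting is the distribution focal loss trick: a fractional regression target is split across its two neighbouring integer bins, with weights that reconstruct the original value. For `target = 2.7`: `tl = 2`, `tr = 3`, `wl = 0.3`, `wr = 0.7`. A toy check:

```python
import torch

target = torch.tensor([2.7])
tl = target.long()  # target left bin -> 2
tr = tl + 1         # target right bin -> 3
wl = tr - target    # weight left -> 0.3
wr = 1 - wl         # weight right -> 0.7
print(tl * wl + tr * wr)  # tensor([2.7000]); the weighted bins recover the target
```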

@@ -238,14 +238,14 @@ class ConfusionMatrix:
        nc, nn = self.nc, len(names)  # number of classes, names
        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
        ticklabels = (names + ['background']) if labels else 'auto'
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
            sn.heatmap(array,
                       ax=ax,
                       annot=nc < 30,
                       annot_kws={
                           'size': 8},
                       cmap='Blues',
                       fmt='.2f',
                       square=True,

@@ -287,7 +287,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
    ax.set_title('Precision-Recall Curve')
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)

@@ -309,7 +309,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
    ax.set_title(f'{ylabel}-Confidence Curve')
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)

@@ -343,7 +343,7 @@ def compute_ap(recall, precision):
    return ap, mpre, mrec


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments

@@ -507,7 +507,7 @@ class Metric:
class DetMetrics:

    def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
        self.save_dir = save_dir
        self.plot = plot
        self.names = names

@@ -521,7 +521,7 @@ class DetMetrics:
    @property
    def keys(self):
        return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']

    def mean_results(self):
        return self.box.mean_results()

@@ -543,12 +543,12 @@ class DetMetrics:
    @property
    def results_dict(self):
        return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))


class SegmentMetrics:

    def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
        self.save_dir = save_dir
        self.plot = plot
        self.names = names

@@ -563,7 +563,7 @@ class SegmentMetrics:
                                   plot=self.plot,
                                   save_dir=self.save_dir,
                                   names=self.names,
                                   prefix='Mask')[2:]
        self.seg.nc = len(self.names)
        self.seg.update(results_mask)
        results_box = ap_per_class(tp_b,

@@ -573,15 +573,15 @@ class SegmentMetrics:
                                   plot=self.plot,
                                   save_dir=self.save_dir,
                                   names=self.names,
                                   prefix='Box')[2:]
        self.box.nc = len(self.names)
        self.box.update(results_box)

    @property
    def keys(self):
        return [
            'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
            'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']

    def mean_results(self):
        return self.box.mean_results() + self.seg.mean_results()

@@ -604,7 +604,7 @@ class SegmentMetrics:
    @property
    def results_dict(self):
        return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))


class ClassifyMetrics:

@@ -626,8 +626,8 @@ class ClassifyMetrics:
    @property
    def results_dict(self):
        return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))

    @property
    def keys(self):
        return ['metrics/accuracy_top1', 'metrics/accuracy_top5']

@@ -715,4 +715,4 @@ def clean_str(s):
    Returns:
        (str): a string with special characters replaced by an underscore _
    """
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
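Since the full function body is visible above, its behaviour is easy to verify standalone:

```python
import re

def clean_str(s):
    # Same regex as above: special characters become underscores
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)

print(clean_str('rtsp://user@host:554/stream?x=1'))  # rtsp_//user_host_554/stream_x_1
```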

@@ -61,7 +61,7 @@ def DDP_model(model):
def select_device(device='', batch=0, newline=False):
    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
    s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).lower()
    for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
        device = device.replace(remove, '')  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'

@@ -74,15 +74,15 @@ def select_device(device='', batch=0, newline=False):
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
            LOGGER.info(s)
            install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
                      'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
            raise ValueError(f"Invalid CUDA 'device={device}' requested."
                             f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                             f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                             f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
                             f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
                             f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                             f'{install}')

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7

@@ -177,7 +177,7 @@ def model_info(model, verbose=False, imgsz=640):
    fused = ' (fused)' if model.is_fused() else ''
    fs = f', {flops:.1f} GFLOPs' if flops else ''
    m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
    LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')


def get_num_params(model):
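The device string is normalized before `CUDA_VISIBLE_DEVICES` is set, so `0`, `'0'`, `'cuda:0'` and `'(0, 1)'` all collapse to bare indices:

```python
device = str('cuda:0').lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
    device = device.replace(remove, '')
print(device)  # '0', which then becomes CUDA_VISIBLE_DEVICES
```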

@@ -2,4 +2,4 @@
from ultralytics.yolo.v8 import classify, detect, segment

__all__ = ['classify', 'segment', 'detect']

@@ -4,4 +4,4 @@ from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predic
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val

__all__ = ['ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val']

@@ -28,7 +28,7 @@ class ClassificationPredictor(BasePredictor):
    def write_results(self, idx, results, batch):
        p, im, im0 = batch
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1

@@ -65,9 +65,9 @@ class ClassificationPredictor(BasePredictor):
def predict(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
    source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
        else 'https://ultralytics.com/images/bus.jpg'
    args = dict(model=model, source=source)
    if use_python:

@@ -78,5 +78,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
        predictor.predict_cli()


if __name__ == '__main__':
    predict()

@@ -16,14 +16,14 @@ class ClassificationTrainer(BaseTrainer):

    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
        if overrides is None:
            overrides = {}
        overrides['task'] = 'classify'
        super().__init__(cfg, overrides)

    def set_model_attributes(self):
        self.model.names = self.data['names']

    def get_model(self, cfg=None, weights=None, verbose=True):
        model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

@@ -53,11 +53,11 @@ class ClassificationTrainer(BaseTrainer):
        model = str(self.model)
        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
        if model.endswith('.pt'):
            self.model, _ = attempt_load_one_weight(model, device='cpu')
            for p in self.model.parameters():
                p.requires_grad = True  # for training
        elif model.endswith('.yaml'):
            self.model = self.get_model(cfg=model)
        elif model in torchvision.models.__dict__:
            pretrained = True

@@ -67,15 +67,15 @@ class ClassificationTrainer(BaseTrainer):
        return  # don't return ckpt. Classification doesn't support resume

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        loader = build_classification_dataloader(path=dataset_path,
                                                 imgsz=self.args.imgsz,
                                                 batch_size=batch_size if mode == 'train' else (batch_size * 2),
                                                 augment=mode == 'train',
                                                 rank=rank,
                                                 workers=self.args.workers)
        # Attach inference transforms
        if mode != 'train':
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:

@@ -83,8 +83,8 @@ class ClassificationTrainer(BaseTrainer):
        return loader

    def preprocess_batch(self, batch):
        batch['img'] = batch['img'].to(self.device)
        batch['cls'] = batch['cls'].to(self.device)
        return batch

    def progress_string(self):

@@ -96,7 +96,7 @@ class ClassificationTrainer(BaseTrainer):
        return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)

    def criterion(self, preds, batch):
        loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
        loss_items = loss.detach()
        return loss, loss_items

@@ -112,12 +112,12 @@ class ClassificationTrainer(BaseTrainer):
    # else:
    #     return keys

    def label_loss_items(self, loss_items=None, prefix='train'):
        """
        Returns a loss dict with labelled training loss items tensor
        """
        # Not needed for classification but necessary for segmentation & detection
        keys = [f'{prefix}/{x}' for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]

@@ -140,8 +140,8 @@ class ClassificationTrainer(BaseTrainer):
def train(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
    data = cfg.data or 'mnist160'  # or yolo.ClassificationDataset("mnist")
    device = cfg.device if cfg.device is not None else ''

    args = dict(model=model, data=data, device=device)

@@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
    trainer.train()


if __name__ == '__main__':
    train()
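For reference, the same defaults are reachable from the `yolo train ...` CLI syntax shown elsewhere in this diff, or from the public Python API (assumed unchanged by this commit):

```python
from ultralytics import YOLO

model = YOLO('yolov8n-cls.pt')          # classification checkpoint
model.train(data='mnist160', epochs=3)  # same model/data defaults as the entry point above
# CLI equivalent: yolo train model=yolov8n-cls.pt data=mnist160
```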

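The classification trainer above can also be driven directly from Python instead of the `train()` helper. A minimal sketch, assuming the `ultralytics.yolo.v8` API as of this commit (the `overrides` keyword comes from `BaseTrainer`; the values are illustrative):

from ultralytics.yolo import v8

# Any keys accepted by DEFAULT_CFG can be overridden here
args = dict(model='yolov8n-cls.pt', data='mnist160', device='')
trainer = v8.classify.ClassificationTrainer(overrides=args)
trainer.train()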
@@ -21,14 +21,14 @@ class ClassificationValidator(BaseValidator):
        self.targets = []

    def preprocess(self, batch):
        batch['img'] = batch['img'].to(self.device, non_blocking=True)
        batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
        batch['cls'] = batch['cls'].to(self.device)
        return batch

    def update_metrics(self, preds, batch):
        self.pred.append(preds.argsort(1, descending=True)[:, :5])
        self.targets.append(batch['cls'])

    def get_stats(self):
        self.metrics.process(self.targets, self.pred)
@@ -42,12 +42,12 @@ class ClassificationValidator(BaseValidator):
    def print_results(self):
        pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format
        self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))


def val(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
    data = cfg.data or 'mnist160'
    args = dict(model=model, data=data)
    if use_python:
@@ -58,5 +58,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
    validator(model=args['model'])


if __name__ == '__main__':
    val()

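The validator follows the same pattern; a hedged sketch mirroring the `val()` helper above (the `args` keyword comes from `BaseValidator`):

from ultralytics.yolo import v8

args = dict(model='yolov8n-cls.pt', data='mnist160')
validator = v8.classify.ClassificationValidator(args=args)
validator(model=args['model'])  # prints top1/top5 via print_results()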
@@ -4,4 +4,4 @@ from .predict import DetectionPredictor, predict
from .train import DetectionTrainer, train
from .val import DetectionValidator, val

__all__ = ['DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val']

@@ -37,7 +37,7 @@ class DetectionPredictor(BasePredictor):
    def write_results(self, idx, results, batch):
        p, im, im0 = batch
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
@@ -69,7 +69,7 @@ class DetectionPredictor(BasePredictor):
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
            if self.args.save or self.args.save_crop or self.args.show:  # Add bbox to image
                c = int(cls)  # integer class
                name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
            if self.args.save_crop:
@@ -82,9 +82,9 @@ class DetectionPredictor(BasePredictor):

def predict(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n.pt'
    source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
        else 'https://ultralytics.com/images/bus.jpg'
    args = dict(model=model, source=source)
    if use_python:
@@ -95,5 +95,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
    predictor.predict_cli()


if __name__ == '__main__':
    predict()

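`predict()` above builds a `DetectionPredictor` and hands it the parsed overrides; the equivalent direct call, as a sketch under the same API assumptions:

from ultralytics.yolo import v8

args = dict(model='yolov8n.pt', source='https://ultralytics.com/images/bus.jpg')
predictor = v8.detect.DetectionPredictor(overrides=args)
predictor.predict_cli()  # streams results and writes annotated output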
@@ -20,7 +20,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel

# BaseTrainer python usage
class DetectionTrainer(BaseTrainer):

    def get_dataloader(self, dataset_path, batch_size, mode='train', rank=0):
        # TODO: manage splits differently
        # calculate stride - check if model is initialized
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
@@ -29,21 +29,21 @@ class DetectionTrainer(BaseTrainer):
                                 batch_size=batch_size,
                                 stride=gs,
                                 hyp=vars(self.args),
                                 augment=mode == 'train',
                                 cache=self.args.cache,
                                 pad=0 if mode == 'train' else 0.5,
                                 rect=self.args.rect or mode == 'val',
                                 rank=rank,
                                 workers=self.args.workers,
                                 close_mosaic=self.args.close_mosaic != 0,
                                 prefix=colorstr(f'{mode}: '),
                                 shuffle=mode == 'train',
                                 seed=self.args.seed)[0] if self.args.v5loader else \
            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode,
                             rect=mode == 'val', names=self.data['names'])[0]

    def preprocess_batch(self, batch):
        batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
        return batch

    def set_model_attributes(self):
@@ -51,13 +51,13 @@ class DetectionTrainer(BaseTrainer):
        # self.args.box *= 3 / nl  # scale to layers
        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
        self.model.nc = self.data['nc']  # attach number of classes to model
        self.model.names = self.data['names']  # attach class names to model
        self.model.args = self.args  # attach hyperparameters to model
        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

    def get_model(self, cfg=None, weights=None, verbose=True):
        model = DetectionModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
@@ -75,12 +75,12 @@ class DetectionTrainer(BaseTrainer):
            self.compute_loss = Loss(de_parallel(self.model))
        return self.compute_loss(preds, batch)

    def label_loss_items(self, loss_items=None, prefix='train'):
        """
        Returns a loss dict with labelled training loss items tensor
        """
        # Not needed for classification but necessary for segmentation & detection
        keys = [f'{prefix}/{x}' for x in self.loss_names]
        if loss_items is not None:
            loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
            return dict(zip(keys, loss_items))
@@ -92,12 +92,12 @@ class DetectionTrainer(BaseTrainer):
                (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')

    def plot_training_samples(self, batch, ni):
        plot_images(images=batch['img'],
                    batch_idx=batch['batch_idx'],
                    cls=batch['cls'].squeeze(-1),
                    bboxes=batch['bboxes'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'train_batch{ni}.jpg')

    def plot_metrics(self):
        plot_results(file=self.csv)  # save results.png
@@ -169,7 +169,7 @@ class Loss:
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
@@ -201,8 +201,8 @@ class Loss:

def train(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n.pt'
    data = cfg.data or 'coco128.yaml'  # or yolo.ClassificationDataset("mnist")
    device = cfg.device if cfg.device is not None else ''
    args = dict(model=model, data=data, device=device)
@@ -214,5 +214,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
    trainer.train()


if __name__ == '__main__':
    train()

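`label_loss_items` above only zips loss names with rounded scalars; a self-contained illustration with made-up values (the loss names are the usual detection trio, not read from a real model):

import torch

loss_names = ['box_loss', 'cls_loss', 'dfl_loss']  # illustrative
loss_items = torch.tensor([0.123456, 0.654321, 0.111111])
keys = [f'train/{x}' for x in loss_names]
print(dict(zip(keys, (round(float(x), 5) for x in loss_items))))
# {'train/box_loss': 0.12346, 'train/cls_loss': 0.65432, 'train/dfl_loss': 0.11111}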
@@ -28,13 +28,13 @@ class DetectionValidator(BaseValidator):
        self.niou = self.iouv.numel()

    def preprocess(self, batch):
        batch['img'] = batch['img'].to(self.device, non_blocking=True)
        batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
        for k in ['batch_idx', 'cls', 'bboxes']:
            batch[k] = batch[k].to(self.device)

        nb = len(batch['img'])
        self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i]
                   for i in range(nb)] if self.args.save_hybrid else []  # for autolabelling
        return batch
@@ -54,7 +54,7 @@ class DetectionValidator(BaseValidator):
        self.stats = []

    def get_desc(self):
        return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')

    def postprocess(self, preds):
        preds = ops.non_max_suppression(preds,
@@ -69,11 +69,11 @@ class DetectionValidator(BaseValidator):
    def update_metrics(self, preds, batch):
        # Metrics
        for si, pred in enumerate(preds):
            idx = batch['batch_idx'] == si
            cls = batch['cls'][idx]
            bbox = batch['bboxes'][idx]
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
            shape = batch['ori_shape'][si]
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            self.seen += 1
@@ -88,16 +88,16 @@ class DetectionValidator(BaseValidator):
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred

            # Evaluate
            if nl:
                height, width = batch['img'].shape[2:]
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
                    (width, height, width, height), device=self.device)  # target boxes
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
                correct_bboxes = self._process_batch(predn, labelsn)
                # TODO: maybe remove these `self.` arguments as they are already member variables
@@ -107,7 +107,7 @@ class DetectionValidator(BaseValidator):

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch['im_file'][si])
            # if self.args.save_txt:
            #     save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
@@ -120,7 +120,7 @@ class DetectionValidator(BaseValidator):
    def print_results(self):
        pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys)  # print format
        self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
        if self.nt_per_class.sum() == 0:
            self.logger.warning(
                f'WARNING ⚠️ no labels found in {self.args.task} set, cannot compute metrics without labels')
@@ -175,21 +175,21 @@ class DetectionValidator(BaseValidator):
                                 shuffle=False,
                                 seed=self.args.seed)[0] if self.args.v5loader else \
            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, names=self.data['names'],
                             mode='val')[0]

    def plot_val_samples(self, batch, ni):
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
                    names=self.names)

    def plot_predictions(self, batch, preds, ni):
        plot_images(batch['img'],
                    *output_to_target(preds, max_det=15),
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
                    names=self.names)  # pred
@@ -207,8 +207,8 @@ class DetectionValidator(BaseValidator):
    def eval_json(self, stats):
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
            pred_json = self.save_dir / 'predictions.json'  # predictions
            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements('pycocotools>=2.0.6')
@@ -216,7 +216,7 @@ class DetectionValidator(BaseValidator):
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f'{x} file not found'
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                eval = COCOeval(anno, pred, 'bbox')
@@ -232,8 +232,8 @@ class DetectionValidator(BaseValidator):

def val(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n.pt'
    data = cfg.data or 'coco128.yaml'
    args = dict(model=model, data=data)
    if use_python:
@@ -244,5 +244,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
    validator(model=args['model'])


if __name__ == '__main__':
    val()

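`eval_json` above wraps the standard pycocotools loop; stripped of the Ultralytics plumbing, it reduces to roughly this (the file paths are placeholders):

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno = COCO('annotations/instances_val2017.json')  # ground-truth annotations
pred = anno.loadRes('predictions.json')  # detections (string path, not Path)
coco_eval = COCOeval(anno, pred, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # prints the 12 standard COCO metrics
map50_95, map50 = coco_eval.stats[:2]  # mAP50-95 and mAP50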
@@ -4,4 +4,4 @@ from .predict import SegmentationPredictor, predict
from .train import SegmentationTrainer, train
from .val import SegmentationValidator, val

__all__ = ['SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val']

@@ -39,7 +39,7 @@ class SegmentationPredictor(DetectionPredictor):
    def write_results(self, idx, results, batch):
        p, im, im0 = batch
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
@@ -84,7 +84,7 @@ class SegmentationPredictor(DetectionPredictor):
            if self.args.save or self.args.save_crop or self.args.show:  # Add bbox to image
                c = int(cls)  # integer class
                name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
            if self.args.save_crop:
@@ -97,9 +97,9 @@ class SegmentationPredictor(DetectionPredictor):

def predict(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n-seg.pt'
    source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
        else 'https://ultralytics.com/images/bus.jpg'
    args = dict(model=model, source=source)
    if use_python:
@@ -110,5 +110,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
    predictor.predict_cli()


if __name__ == '__main__':
    predict()

@@ -20,11 +20,11 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
        if overrides is None:
            overrides = {}
        overrides['task'] = 'segment'
        super().__init__(cfg, overrides)

    def get_model(self, cfg=None, weights=None, verbose=True):
        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
@@ -43,13 +43,13 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
        return self.compute_loss(preds, batch)

    def plot_training_samples(self, batch, ni):
        images = batch['img']
        masks = batch['masks']
        cls = batch['cls'].squeeze(-1)
        bboxes = batch['bboxes']
        paths = batch['im_file']
        batch_idx = batch['batch_idx']
        plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')

    def plot_metrics(self):
        plot_results(file=self.csv, segment=True)  # save results.png
@@ -80,15 +80,15 @@ class SegLoss(Loss):
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        batch_idx = batch['batch_idx'].view(-1, 1)
        targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        masks = batch['masks'].to(self.device).float()
        if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
            masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

        # pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
@@ -135,13 +135,13 @@ class SegLoss(Loss):
    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
        # Mask loss for one image
        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()


def train(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n-seg.pt'
    data = cfg.data or 'coco128-seg.yaml'  # or yolo.ClassificationDataset("mnist")
    device = cfg.device if cfg.device is not None else ''
    args = dict(model=model, data=data, device=device)
@@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
    trainer.train()


if __name__ == '__main__':
    train()

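`single_mask_loss` above projects per-instance coefficients through the prototype bank before BCE; a self-contained sketch with random tensors (shapes are illustrative only):

import torch
import torch.nn.functional as F

n, nm, h, w = 3, 32, 80, 80  # 3 instances, 32 coefficients, 80x80 prototypes
pred = torch.randn(n, nm)  # mask coefficients per instance
proto = torch.randn(nm, h, w)  # prototype masks
gt_mask = torch.randint(0, 2, (n, h, w)).float()

pred_mask = (pred @ proto.view(nm, -1)).view(-1, h, w)  # (n,32) @ (32,6400) -> (n,80,80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
print(loss.shape)  # torch.Size([3, 80, 80]); crop_mask() then restricts the loss to each box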
@@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
    def preprocess(self, batch):
        batch = super().preprocess(batch)
        batch['masks'] = batch['masks'].to(self.device).float()
        return batch

    def init_metrics(self, model):
@@ -37,8 +37,8 @@ class SegmentationValidator(DetectionValidator):
            self.process = ops.process_mask  # faster

    def get_desc(self):
        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
                                         'R', 'mAP50', 'mAP50-95)')

    def postprocess(self, preds):
        p = ops.non_max_suppression(preds[0],
@@ -55,11 +55,11 @@ class SegmentationValidator(DetectionValidator):
    def update_metrics(self, preds, batch):
        # Metrics
        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
            idx = batch['batch_idx'] == si
            cls = batch['cls'][idx]
            bbox = batch['bboxes'][idx]
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
            shape = batch['ori_shape'][si]
            correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            self.seen += 1
@@ -74,23 +74,23 @@ class SegmentationValidator(DetectionValidator):

            # Masks
            midx = [si] if self.args.overlap_mask else idx
            gt_masks = batch['masks'][midx]
            pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred

            # Evaluate
            if nl:
                height, width = batch['img'].shape[2:]
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
                    (width, height, width, height), device=self.device)  # target boxes
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
                correct_bboxes = self._process_batch(predn, labelsn)
                # TODO: maybe remove these `self.` arguments as they are already member variables
@@ -112,11 +112,11 @@ class SegmentationValidator(DetectionValidator):

            # Save
            if self.args.save_json:
                pred_masks = ops.scale_image(batch['img'][si].shape[1:],
                                             pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
                                             shape,
                                             ratio_pad=batch['ratio_pad'][si])
                self.pred_to_json(predn, batch['im_file'][si], pred_masks)
            # if self.args.save_txt:
            #     save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
@@ -136,7 +136,7 @@ class SegmentationValidator(DetectionValidator):
                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
            if gt_masks.shape[1:] != pred_masks.shape[1:]:
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
                gt_masks = gt_masks.gt_(0.5)
            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
        else:  # boxes
@@ -158,20 +158,20 @@ class SegmentationValidator(DetectionValidator):
        return torch.tensor(correct, dtype=torch.bool, device=detections.device)

    def plot_val_samples(self, batch, ni):
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    batch['masks'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
                    names=self.names)

    def plot_predictions(self, batch, preds, ni):
        plot_images(batch['img'],
                    *output_to_target(preds[0], max_det=15),
                    torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
                    names=self.names)  # pred
        self.plot_masks.clear()
@@ -182,8 +182,8 @@ class SegmentationValidator(DetectionValidator):
        from pycocotools.mask import encode  # noqa

        def single_encode(x):
            rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
            rle['counts'] = rle['counts'].decode('utf-8')
            return rle

        stem = Path(filename).stem
@@ -203,8 +203,8 @@ class SegmentationValidator(DetectionValidator):
    def eval_json(self, stats):
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
            pred_json = self.save_dir / 'predictions.json'  # predictions
            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements('pycocotools>=2.0.6')
@@ -212,7 +212,7 @@ class SegmentationValidator(DetectionValidator):
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f'{x} file not found'
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
@@ -231,8 +231,8 @@ class SegmentationValidator(DetectionValidator):

def val(cfg=DEFAULT_CFG, use_python=False):
    model = cfg.model or 'yolov8n-seg.pt'
    data = cfg.data or 'coco128-seg.yaml'
    args = dict(model=model, data=data)
    if use_python:
@@ -243,5 +243,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
    validator(model=args['model'])


if __name__ == '__main__':
    val()

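`single_encode` above produces COCO run-length encodings for the JSON dump; the same call in isolation, assuming only pycocotools and a toy mask:

import numpy as np
from pycocotools.mask import decode, encode

mask = np.zeros((160, 160), dtype=np.uint8)
mask[40:120, 40:120] = 1  # a square instance mask

rle = encode(np.asarray(mask[:, :, None], order='F', dtype='uint8'))[0]
assert (decode([rle])[:, :, 0] == mask).all()  # lossless round trip
rle['counts'] = rle['counts'].decode('utf-8')  # bytes -> str so json.dump() accepts it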