Update .pre-commit-config.yaml (#1026)
				
					
				
			This commit is contained in:
		@ -1,8 +1,5 @@
 | 
			
		||||
# Define hooks for code formations
 | 
			
		||||
# Will be applied on any updated commit files if a user has installed and linked commit hook
 | 
			
		||||
 | 
			
		||||
default_language_version:
 | 
			
		||||
  python: python3.8
 | 
			
		||||
# Ultralytics YOLO 🚀, GPL-3.0 license
 | 
			
		||||
# Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
 | 
			
		||||
 | 
			
		||||
exclude: 'docs/'
 | 
			
		||||
# Define bot property if installed via https://github.com/marketplace/pre-commit-ci
 | 
			
		||||
@ -16,13 +13,13 @@ repos:
 | 
			
		||||
  - repo: https://github.com/pre-commit/pre-commit-hooks
 | 
			
		||||
    rev: v4.4.0
 | 
			
		||||
    hooks:
 | 
			
		||||
      # - id: end-of-file-fixer
 | 
			
		||||
      - id: end-of-file-fixer
 | 
			
		||||
      - id: trailing-whitespace
 | 
			
		||||
      - id: check-case-conflict
 | 
			
		||||
      - id: check-yaml
 | 
			
		||||
      - id: check-toml
 | 
			
		||||
      - id: pretty-format-json
 | 
			
		||||
      - id: check-docstring-first
 | 
			
		||||
      - id: double-quote-string-fixer
 | 
			
		||||
      - id: detect-private-key
 | 
			
		||||
 | 
			
		||||
  - repo: https://github.com/asottile/pyupgrade
 | 
			
		||||
    rev: v3.3.1
 | 
			
		||||
@ -64,7 +61,7 @@ repos:
 | 
			
		||||
    hooks:
 | 
			
		||||
      - id: codespell
 | 
			
		||||
        args:
 | 
			
		||||
          - --ignore-words-list=crate,nd,strack
 | 
			
		||||
          - --ignore-words-list=crate,nd,strack,dota
 | 
			
		||||
 | 
			
		||||
  #- repo: https://github.com/asottile/yesqa
 | 
			
		||||
  #  rev: v1.4.0
 | 
			
		||||
 | 
			
		||||
@ -31,8 +31,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 | 
			
		||||
# Install pip packages
 | 
			
		||||
COPY requirements.txt .
 | 
			
		||||
RUN python3 -m pip install --upgrade pip wheel
 | 
			
		||||
RUN pip install --no-cache ultralytics[export] albumentations comet gsutil notebook \
 | 
			
		||||
    # tensorflow tensorflowjs \
 | 
			
		||||
RUN pip install --no-cache ultralytics[export] albumentations comet gsutil notebook
 | 
			
		||||
 | 
			
		||||
# Set environment variables
 | 
			
		||||
ENV OMP_NUM_THREADS=1
 | 
			
		||||
 | 
			
		||||
@ -27,8 +27,6 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 | 
			
		||||
COPY requirements.txt .
 | 
			
		||||
RUN python3 -m pip install --upgrade pip wheel
 | 
			
		||||
RUN pip install --no-cache ultralytics albumentations gsutil notebook
 | 
			
		||||
    # coremltools onnx onnxruntime \
 | 
			
		||||
    # tensorflow-aarch64 tensorflowjs \
 | 
			
		||||
 | 
			
		||||
# Cleanup
 | 
			
		||||
ENV DEBIAN_FRONTEND teletype
 | 
			
		||||
 | 
			
		||||
@ -27,8 +27,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 | 
			
		||||
COPY requirements.txt .
 | 
			
		||||
RUN python3 -m pip install --upgrade pip wheel
 | 
			
		||||
RUN pip install --no-cache ultralytics[export] albumentations gsutil notebook \
 | 
			
		||||
    # tensorflow-cpu tensorflowjs \
 | 
			
		||||
    --extra-index-url https://download.pytorch.org/whl/cpu
 | 
			
		||||
        --extra-index-url https://download.pytorch.org/whl/cpu
 | 
			
		||||
 | 
			
		||||
# Cleanup
 | 
			
		||||
ENV DEBIAN_FRONTEND teletype
 | 
			
		||||
 | 
			
		||||
@ -92,7 +92,7 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc.
 | 
			
		||||
 | 
			
		||||
## Overriding default arguments
 | 
			
		||||
 | 
			
		||||
Default arguments can be overriden by simply passing them as arguments in the CLI in `arg=value` pairs.
 | 
			
		||||
Default arguments can be overridden by simply passing them as arguments in the CLI in `arg=value` pairs.
 | 
			
		||||
 | 
			
		||||
!!! tip ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -96,7 +96,7 @@ Class reference documentation for `Results` module and its components can be fou
 | 
			
		||||
 | 
			
		||||
## Visualizing results
 | 
			
		||||
 | 
			
		||||
You can use `visualize()` function of `Result` object to get a visualization. It plots all componenets(boxes, masks, classification logits, etc) found in the results object
 | 
			
		||||
You can use `visualize()` function of `Result` object to get a visualization. It plots all components(boxes, masks, classification logits, etc) found in the results object
 | 
			
		||||
```python
 | 
			
		||||
    res = model(img)
 | 
			
		||||
    res_plotted = res[0].visualize()
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@ The simplest way of simply using YOLOv8 directly in a Python environment.
 | 
			
		||||
 | 
			
		||||
!!! example "Train"
 | 
			
		||||
 | 
			
		||||
    === "From pretrained(recommanded)"
 | 
			
		||||
    === "From pretrained(recommended)"
 | 
			
		||||
        ```python
 | 
			
		||||
        from ultralytics import YOLO
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							@ -16,7 +16,7 @@ PKG_REQUIREMENTS = ['sentry_sdk']  # pip-only requirements
 | 
			
		||||
 | 
			
		||||
def get_version():
 | 
			
		||||
    file = PARENT / 'ultralytics/__init__.py'
 | 
			
		||||
    return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding="utf-8"), re.M)[1]
 | 
			
		||||
    return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding='utf-8'), re.M)[1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
setup(
 | 
			
		||||
 | 
			
		||||
@ -49,9 +49,9 @@ def test_val_classify():
 | 
			
		||||
# Predict checks -------------------------------------------------------------------------------------------------------
 | 
			
		||||
def test_predict_detect():
 | 
			
		||||
    run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
 | 
			
		||||
    run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
 | 
			
		||||
    run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32")
 | 
			
		||||
    run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32")
 | 
			
		||||
    run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
 | 
			
		||||
    run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
 | 
			
		||||
    run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_predict_segment():
 | 
			
		||||
 | 
			
		||||
@ -11,12 +11,12 @@ CFG_SEG = 'yolov8n-seg.yaml'
 | 
			
		||||
CFG_CLS = 'squeezenet1_0'
 | 
			
		||||
CFG = get_cfg(DEFAULT_CFG)
 | 
			
		||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
 | 
			
		||||
SOURCE = ROOT / "assets"
 | 
			
		||||
SOURCE = ROOT / 'assets'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_detect():
 | 
			
		||||
    overrides = {"data": "coco8.yaml", "model": CFG_DET, "imgsz": 32, "epochs": 1, "save": False}
 | 
			
		||||
    CFG.data = "coco8.yaml"
 | 
			
		||||
    overrides = {'data': 'coco8.yaml', 'model': CFG_DET, 'imgsz': 32, 'epochs': 1, 'save': False}
 | 
			
		||||
    CFG.data = 'coco8.yaml'
 | 
			
		||||
 | 
			
		||||
    # Trainer
 | 
			
		||||
    trainer = detect.DetectionTrainer(overrides=overrides)
 | 
			
		||||
@ -27,24 +27,24 @@ def test_detect():
 | 
			
		||||
    val(model=trainer.best)  # validate best.pt
 | 
			
		||||
 | 
			
		||||
    # Predictor
 | 
			
		||||
    pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
 | 
			
		||||
    result = pred(source=SOURCE, model=f"{MODEL}.pt")
 | 
			
		||||
    assert len(result), "predictor test failed"
 | 
			
		||||
    pred = detect.DetectionPredictor(overrides={'imgsz': [64, 64]})
 | 
			
		||||
    result = pred(source=SOURCE, model=f'{MODEL}.pt')
 | 
			
		||||
    assert len(result), 'predictor test failed'
 | 
			
		||||
 | 
			
		||||
    overrides["resume"] = trainer.last
 | 
			
		||||
    overrides['resume'] = trainer.last
 | 
			
		||||
    trainer = detect.DetectionTrainer(overrides=overrides)
 | 
			
		||||
    try:
 | 
			
		||||
        trainer.train()
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"Expected exception caught: {e}")
 | 
			
		||||
        print(f'Expected exception caught: {e}')
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    Exception("Resume test failed!")
 | 
			
		||||
    Exception('Resume test failed!')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_segment():
 | 
			
		||||
    overrides = {"data": "coco8-seg.yaml", "model": CFG_SEG, "imgsz": 32, "epochs": 1, "save": False}
 | 
			
		||||
    CFG.data = "coco8-seg.yaml"
 | 
			
		||||
    overrides = {'data': 'coco8-seg.yaml', 'model': CFG_SEG, 'imgsz': 32, 'epochs': 1, 'save': False}
 | 
			
		||||
    CFG.data = 'coco8-seg.yaml'
 | 
			
		||||
    CFG.v5loader = False
 | 
			
		||||
    # YOLO(CFG_SEG).train(**overrides)  # works
 | 
			
		||||
 | 
			
		||||
@ -57,25 +57,25 @@ def test_segment():
 | 
			
		||||
    val(model=trainer.best)  # validate best.pt
 | 
			
		||||
 | 
			
		||||
    # Predictor
 | 
			
		||||
    pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
 | 
			
		||||
    result = pred(source=SOURCE, model=f"{MODEL}-seg.pt")
 | 
			
		||||
    assert len(result) == 2, "predictor test failed"
 | 
			
		||||
    pred = segment.SegmentationPredictor(overrides={'imgsz': [64, 64]})
 | 
			
		||||
    result = pred(source=SOURCE, model=f'{MODEL}-seg.pt')
 | 
			
		||||
    assert len(result) == 2, 'predictor test failed'
 | 
			
		||||
 | 
			
		||||
    # Test resume
 | 
			
		||||
    overrides["resume"] = trainer.last
 | 
			
		||||
    overrides['resume'] = trainer.last
 | 
			
		||||
    trainer = segment.SegmentationTrainer(overrides=overrides)
 | 
			
		||||
    try:
 | 
			
		||||
        trainer.train()
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"Expected exception caught: {e}")
 | 
			
		||||
        print(f'Expected exception caught: {e}')
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    Exception("Resume test failed!")
 | 
			
		||||
    Exception('Resume test failed!')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_classify():
 | 
			
		||||
    overrides = {"data": "mnist160", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "batch": 64, "save": False}
 | 
			
		||||
    CFG.data = "mnist160"
 | 
			
		||||
    overrides = {'data': 'mnist160', 'model': 'yolov8n-cls.yaml', 'imgsz': 32, 'epochs': 1, 'batch': 64, 'save': False}
 | 
			
		||||
    CFG.data = 'mnist160'
 | 
			
		||||
    CFG.imgsz = 32
 | 
			
		||||
    CFG.batch = 64
 | 
			
		||||
    # YOLO(CFG_SEG).train(**overrides)  # works
 | 
			
		||||
@ -89,6 +89,6 @@ def test_classify():
 | 
			
		||||
    val(model=trainer.best)
 | 
			
		||||
 | 
			
		||||
    # Predictor
 | 
			
		||||
    pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
 | 
			
		||||
    pred = classify.ClassificationPredictor(overrides={'imgsz': [64, 64]})
 | 
			
		||||
    result = pred(source=SOURCE, model=trainer.best)
 | 
			
		||||
    assert len(result) == 2, "predictor test failed"
 | 
			
		||||
    assert len(result) == 2, 'predictor test failed'
 | 
			
		||||
 | 
			
		||||
@ -37,24 +37,24 @@ def test_model_fuse():
 | 
			
		||||
 | 
			
		||||
def test_predict_dir():
 | 
			
		||||
    model = YOLO(MODEL)
 | 
			
		||||
    model(source=ROOT / "assets")
 | 
			
		||||
    model(source=ROOT / 'assets')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_predict_img():
 | 
			
		||||
    model = YOLO(MODEL)
 | 
			
		||||
    img = Image.open(str(SOURCE))
 | 
			
		||||
    output = model(source=img, save=True, verbose=True)  # PIL
 | 
			
		||||
    assert len(output) == 1, "predict test failed"
 | 
			
		||||
    assert len(output) == 1, 'predict test failed'
 | 
			
		||||
    img = cv2.imread(str(SOURCE))
 | 
			
		||||
    output = model(source=img, save=True, save_txt=True)  # ndarray
 | 
			
		||||
    assert len(output) == 1, "predict test failed"
 | 
			
		||||
    assert len(output) == 1, 'predict test failed'
 | 
			
		||||
    output = model(source=[img, img], save=True, save_txt=True)  # batch
 | 
			
		||||
    assert len(output) == 2, "predict test failed"
 | 
			
		||||
    assert len(output) == 2, 'predict test failed'
 | 
			
		||||
    output = model(source=[img, img], save=True, stream=True)  # stream
 | 
			
		||||
    assert len(list(output)) == 2, "predict test failed"
 | 
			
		||||
    assert len(list(output)) == 2, 'predict test failed'
 | 
			
		||||
    tens = torch.zeros(320, 640, 3)
 | 
			
		||||
    output = model(tens.numpy())
 | 
			
		||||
    assert len(output) == 1, "predict test failed"
 | 
			
		||||
    assert len(output) == 1, 'predict test failed'
 | 
			
		||||
    # test multiple source
 | 
			
		||||
    imgs = [
 | 
			
		||||
        SOURCE,  # filename
 | 
			
		||||
@ -64,23 +64,23 @@ def test_predict_img():
 | 
			
		||||
        Image.open(SOURCE),  # PIL
 | 
			
		||||
        np.zeros((320, 640, 3))]  # numpy
 | 
			
		||||
    output = model(imgs)
 | 
			
		||||
    assert len(output) == 6, "predict test failed!"
 | 
			
		||||
    assert len(output) == 6, 'predict test failed!'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_val():
 | 
			
		||||
    model = YOLO(MODEL)
 | 
			
		||||
    model.val(data="coco8.yaml", imgsz=32)
 | 
			
		||||
    model.val(data='coco8.yaml', imgsz=32)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_train_scratch():
 | 
			
		||||
    model = YOLO(CFG)
 | 
			
		||||
    model.train(data="coco8.yaml", epochs=1, imgsz=32)
 | 
			
		||||
    model.train(data='coco8.yaml', epochs=1, imgsz=32)
 | 
			
		||||
    model(SOURCE)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_train_pretrained():
 | 
			
		||||
    model = YOLO(MODEL)
 | 
			
		||||
    model.train(data="coco8.yaml", epochs=1, imgsz=32)
 | 
			
		||||
    model.train(data='coco8.yaml', epochs=1, imgsz=32)
 | 
			
		||||
    model(SOURCE)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -139,10 +139,10 @@ def test_all_model_yamls():
 | 
			
		||||
 | 
			
		||||
def test_workflow():
 | 
			
		||||
    model = YOLO(MODEL)
 | 
			
		||||
    model.train(data="coco8.yaml", epochs=1, imgsz=32)
 | 
			
		||||
    model.train(data='coco8.yaml', epochs=1, imgsz=32)
 | 
			
		||||
    model.val()
 | 
			
		||||
    model.predict(SOURCE)
 | 
			
		||||
    model.export(format="onnx")  # export a model to ONNX format
 | 
			
		||||
    model.export(format='onnx')  # export a model to ONNX format
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_predict_callback_and_setup():
 | 
			
		||||
@ -154,8 +154,8 @@ def test_predict_callback_and_setup():
 | 
			
		||||
        bs = [predictor.dataset.bs for _ in range(len(path))]
 | 
			
		||||
        predictor.results = zip(predictor.results, im0s, bs)
 | 
			
		||||
 | 
			
		||||
    model = YOLO("yolov8n.pt")
 | 
			
		||||
    model.add_callback("on_predict_batch_end", on_predict_batch_end)
 | 
			
		||||
    model = YOLO('yolov8n.pt')
 | 
			
		||||
    model.add_callback('on_predict_batch_end', on_predict_batch_end)
 | 
			
		||||
 | 
			
		||||
    dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
 | 
			
		||||
    bs = dataset.bs  # noqa access predictor properties
 | 
			
		||||
@ -168,8 +168,8 @@ def test_predict_callback_and_setup():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_result():
 | 
			
		||||
    model = YOLO("yolov8n-seg.pt")
 | 
			
		||||
    img = str(ROOT / "assets/bus.jpg")
 | 
			
		||||
    model = YOLO('yolov8n-seg.pt')
 | 
			
		||||
    img = str(ROOT / 'assets/bus.jpg')
 | 
			
		||||
    res = model([img, img])
 | 
			
		||||
    res[0].numpy()
 | 
			
		||||
    res[0].cpu().numpy()
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,8 @@
 | 
			
		||||
# Ultralytics YOLO 🚀, GPL-3.0 license
 | 
			
		||||
 | 
			
		||||
__version__ = "8.0.40"
 | 
			
		||||
__version__ = '8.0.40'
 | 
			
		||||
 | 
			
		||||
from ultralytics.yolo.engine.model import YOLO
 | 
			
		||||
from ultralytics.yolo.utils.checks import check_yolo as checks
 | 
			
		||||
 | 
			
		||||
__all__ = ["__version__", "YOLO", "checks"]  # allow simpler import
 | 
			
		||||
__all__ = ['__version__', 'YOLO', 'checks']  # allow simpler import
 | 
			
		||||
 | 
			
		||||
@ -10,10 +10,10 @@ from ultralytics.yolo.engine.model import YOLO
 | 
			
		||||
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
 | 
			
		||||
 | 
			
		||||
# Define all export formats
 | 
			
		||||
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
 | 
			
		||||
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def start(key=""):
 | 
			
		||||
def start(key=''):
 | 
			
		||||
    """
 | 
			
		||||
    Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
 | 
			
		||||
    """
 | 
			
		||||
@ -34,7 +34,7 @@ def start(key=""):
 | 
			
		||||
        session.register_callbacks(trainer)
 | 
			
		||||
        trainer.train(**session.train_args)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        LOGGER.warning(f"{PREFIX}{e}")
 | 
			
		||||
        LOGGER.warning(f'{PREFIX}{e}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def request_api_key(auth, max_attempts=3):
 | 
			
		||||
@ -43,56 +43,56 @@ def request_api_key(auth, max_attempts=3):
 | 
			
		||||
    """
 | 
			
		||||
    import getpass
 | 
			
		||||
    for attempts in range(max_attempts):
 | 
			
		||||
        LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
 | 
			
		||||
        input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
 | 
			
		||||
        input_key = getpass.getpass('Enter your Ultralytics HUB API key:\n')
 | 
			
		||||
        auth.api_key, model_id = split_key(input_key)
 | 
			
		||||
 | 
			
		||||
        if auth.authenticate():
 | 
			
		||||
            LOGGER.info(f"{PREFIX}Authenticated ✅")
 | 
			
		||||
            LOGGER.info(f'{PREFIX}Authenticated ✅')
 | 
			
		||||
            return model_id
 | 
			
		||||
 | 
			
		||||
        LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
 | 
			
		||||
        LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n')
 | 
			
		||||
 | 
			
		||||
    raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
 | 
			
		||||
    raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def reset_model(key=""):
 | 
			
		||||
def reset_model(key=''):
 | 
			
		||||
    # Reset a trained model to an untrained state
 | 
			
		||||
    api_key, model_id = split_key(key)
 | 
			
		||||
    r = requests.post("https://api.ultralytics.com/model-reset", json={"apiKey": api_key, "modelId": model_id})
 | 
			
		||||
    r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id})
 | 
			
		||||
 | 
			
		||||
    if r.status_code == 200:
 | 
			
		||||
        LOGGER.info(f"{PREFIX}model reset successfully")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}model reset successfully')
 | 
			
		||||
        return
 | 
			
		||||
    LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")
 | 
			
		||||
    LOGGER.warning(f'{PREFIX}model reset failure {r.status_code} {r.reason}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export_model(key="", format="torchscript"):
 | 
			
		||||
def export_model(key='', format='torchscript'):
 | 
			
		||||
    # Export a model to all formats
 | 
			
		||||
    assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
 | 
			
		||||
    api_key, model_id = split_key(key)
 | 
			
		||||
    r = requests.post("https://api.ultralytics.com/export",
 | 
			
		||||
    r = requests.post('https://api.ultralytics.com/export',
 | 
			
		||||
                      json={
 | 
			
		||||
                          "apiKey": api_key,
 | 
			
		||||
                          "modelId": model_id,
 | 
			
		||||
                          "format": format})
 | 
			
		||||
    assert (r.status_code == 200), f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
 | 
			
		||||
    LOGGER.info(f"{PREFIX}{format} export started ✅")
 | 
			
		||||
                          'apiKey': api_key,
 | 
			
		||||
                          'modelId': model_id,
 | 
			
		||||
                          'format': format})
 | 
			
		||||
    assert (r.status_code == 200), f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
 | 
			
		||||
    LOGGER.info(f'{PREFIX}{format} export started ✅')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_export(key="", format="torchscript"):
 | 
			
		||||
def get_export(key='', format='torchscript'):
 | 
			
		||||
    # Get an exported model dictionary with download URL
 | 
			
		||||
    assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
 | 
			
		||||
    api_key, model_id = split_key(key)
 | 
			
		||||
    r = requests.post("https://api.ultralytics.com/get-export",
 | 
			
		||||
    r = requests.post('https://api.ultralytics.com/get-export',
 | 
			
		||||
                      json={
 | 
			
		||||
                          "apiKey": api_key,
 | 
			
		||||
                          "modelId": model_id,
 | 
			
		||||
                          "format": format})
 | 
			
		||||
    assert (r.status_code == 200), f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
 | 
			
		||||
                          'apiKey': api_key,
 | 
			
		||||
                          'modelId': model_id,
 | 
			
		||||
                          'format': format})
 | 
			
		||||
    assert (r.status_code == 200), f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
 | 
			
		||||
    return r.json()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# temp. For checking
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    start()
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@ import requests
 | 
			
		||||
from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
 | 
			
		||||
from ultralytics.yolo.utils import is_colab
 | 
			
		||||
 | 
			
		||||
API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys"
 | 
			
		||||
API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Auth:
 | 
			
		||||
@ -18,7 +18,7 @@ class Auth:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _clean_api_key(key: str) -> str:
 | 
			
		||||
        """Strip model from key if present"""
 | 
			
		||||
        separator = "_"
 | 
			
		||||
        separator = '_'
 | 
			
		||||
        return key.split(separator)[0] if separator in key else key
 | 
			
		||||
 | 
			
		||||
    def authenticate(self) -> bool:
 | 
			
		||||
@ -26,11 +26,11 @@ class Auth:
 | 
			
		||||
        try:
 | 
			
		||||
            header = self.get_auth_header()
 | 
			
		||||
            if header:
 | 
			
		||||
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
 | 
			
		||||
                r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
 | 
			
		||||
                if not r.json().get('success', False):
 | 
			
		||||
                    raise ConnectionError("Unable to authenticate.")
 | 
			
		||||
                    raise ConnectionError('Unable to authenticate.')
 | 
			
		||||
                return True
 | 
			
		||||
            raise ConnectionError("User has not authenticated locally.")
 | 
			
		||||
            raise ConnectionError('User has not authenticated locally.')
 | 
			
		||||
        except ConnectionError:
 | 
			
		||||
            self.id_token = self.api_key = False  # reset invalid
 | 
			
		||||
            return False
 | 
			
		||||
@ -43,21 +43,21 @@ class Auth:
 | 
			
		||||
        if not is_colab():
 | 
			
		||||
            return False  # Currently only works with Colab
 | 
			
		||||
        try:
 | 
			
		||||
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
 | 
			
		||||
            if authn.get("success", False):
 | 
			
		||||
                self.id_token = authn.get("data", {}).get("idToken", None)
 | 
			
		||||
            authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
 | 
			
		||||
            if authn.get('success', False):
 | 
			
		||||
                self.id_token = authn.get('data', {}).get('idToken', None)
 | 
			
		||||
                self.authenticate()
 | 
			
		||||
                return True
 | 
			
		||||
            raise ConnectionError("Unable to fetch browser authentication details.")
 | 
			
		||||
            raise ConnectionError('Unable to fetch browser authentication details.')
 | 
			
		||||
        except ConnectionError:
 | 
			
		||||
            self.id_token = False  # reset invalid
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def get_auth_header(self):
 | 
			
		||||
        if self.id_token:
 | 
			
		||||
            return {"authorization": f"Bearer {self.id_token}"}
 | 
			
		||||
            return {'authorization': f'Bearer {self.id_token}'}
 | 
			
		||||
        elif self.api_key:
 | 
			
		||||
            return {"x-api-key": self.api_key}
 | 
			
		||||
            return {'x-api-key': self.api_key}
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_
 | 
			
		||||
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded
 | 
			
		||||
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
 | 
			
		||||
 | 
			
		||||
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
 | 
			
		||||
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
 | 
			
		||||
session = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -20,9 +20,9 @@ class HubTrainingSession:
 | 
			
		||||
    def __init__(self, model_id, auth):
 | 
			
		||||
        self.agent_id = None  # identifies which instance is communicating with server
 | 
			
		||||
        self.model_id = model_id
 | 
			
		||||
        self.api_url = f"{HUB_API_ROOT}/v1/models/{model_id}"
 | 
			
		||||
        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
 | 
			
		||||
        self.auth_header = auth.get_auth_header()
 | 
			
		||||
        self._rate_limits = {"metrics": 3.0, "ckpt": 900.0, "heartbeat": 300.0}  # rate limits (seconds)
 | 
			
		||||
        self._rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
 | 
			
		||||
        self._timers = {}  # rate limit timers (seconds)
 | 
			
		||||
        self._metrics_queue = {}  # metrics queue
 | 
			
		||||
        self.model = self._get_model()
 | 
			
		||||
@ -40,7 +40,7 @@ class HubTrainingSession:
 | 
			
		||||
        passed by signal.
 | 
			
		||||
        """
 | 
			
		||||
        if self.alive is True:
 | 
			
		||||
            LOGGER.info(f"{PREFIX}Kill signal received! ❌")
 | 
			
		||||
            LOGGER.info(f'{PREFIX}Kill signal received! ❌')
 | 
			
		||||
            self._stop_heartbeat()
 | 
			
		||||
            sys.exit(signum)
 | 
			
		||||
 | 
			
		||||
@ -49,23 +49,23 @@ class HubTrainingSession:
 | 
			
		||||
        self.alive = False
 | 
			
		||||
 | 
			
		||||
    def upload_metrics(self):
 | 
			
		||||
        payload = {"metrics": self._metrics_queue.copy(), "type": "metrics"}
 | 
			
		||||
        smart_request(f"{self.api_url}", json=payload, headers=self.auth_header, code=2)
 | 
			
		||||
        payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
 | 
			
		||||
        smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
 | 
			
		||||
 | 
			
		||||
    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
 | 
			
		||||
        # Upload a model to HUB
 | 
			
		||||
        file = None
 | 
			
		||||
        if Path(weights).is_file():
 | 
			
		||||
            with open(weights, "rb") as f:
 | 
			
		||||
            with open(weights, 'rb') as f:
 | 
			
		||||
                file = f.read()
 | 
			
		||||
        if final:
 | 
			
		||||
            smart_request(
 | 
			
		||||
                f"{self.api_url}/upload",
 | 
			
		||||
                f'{self.api_url}/upload',
 | 
			
		||||
                data={
 | 
			
		||||
                    "epoch": epoch,
 | 
			
		||||
                    "type": "final",
 | 
			
		||||
                    "map": map},
 | 
			
		||||
                files={"best.pt": file},
 | 
			
		||||
                    'epoch': epoch,
 | 
			
		||||
                    'type': 'final',
 | 
			
		||||
                    'map': map},
 | 
			
		||||
                files={'best.pt': file},
 | 
			
		||||
                headers=self.auth_header,
 | 
			
		||||
                retry=10,
 | 
			
		||||
                timeout=3600,
 | 
			
		||||
@ -73,66 +73,66 @@ class HubTrainingSession:
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            smart_request(
 | 
			
		||||
                f"{self.api_url}/upload",
 | 
			
		||||
                f'{self.api_url}/upload',
 | 
			
		||||
                data={
 | 
			
		||||
                    "epoch": epoch,
 | 
			
		||||
                    "type": "epoch",
 | 
			
		||||
                    "isBest": bool(is_best)},
 | 
			
		||||
                    'epoch': epoch,
 | 
			
		||||
                    'type': 'epoch',
 | 
			
		||||
                    'isBest': bool(is_best)},
 | 
			
		||||
                headers=self.auth_header,
 | 
			
		||||
                files={"last.pt": file},
 | 
			
		||||
                files={'last.pt': file},
 | 
			
		||||
                code=3,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def _get_model(self):
 | 
			
		||||
        # Returns model from database by id
 | 
			
		||||
        api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
 | 
			
		||||
        api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
 | 
			
		||||
        headers = self.auth_header
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            response = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
 | 
			
		||||
            data = response.json().get("data", None)
 | 
			
		||||
            response = smart_request(api_url, method='get', headers=headers, thread=False, code=0)
 | 
			
		||||
            data = response.json().get('data', None)
 | 
			
		||||
 | 
			
		||||
            if data.get("status", None) == "trained":
 | 
			
		||||
            if data.get('status', None) == 'trained':
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    emojis(f"Model is already trained and uploaded to "
 | 
			
		||||
                           f"https://hub.ultralytics.com/models/{self.model_id} 🚀"))
 | 
			
		||||
                    emojis(f'Model is already trained and uploaded to '
 | 
			
		||||
                           f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
 | 
			
		||||
 | 
			
		||||
            if not data.get("data", None):
 | 
			
		||||
                raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix
 | 
			
		||||
            self.model_id = data["id"]
 | 
			
		||||
            if not data.get('data', None):
 | 
			
		||||
                raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
 | 
			
		||||
            self.model_id = data['id']
 | 
			
		||||
 | 
			
		||||
            # TODO: restore when server keys when dataset URL and GPU train is working
 | 
			
		||||
 | 
			
		||||
            self.train_args = {
 | 
			
		||||
                "batch": data["batch_size"],
 | 
			
		||||
                "epochs": data["epochs"],
 | 
			
		||||
                "imgsz": data["imgsz"],
 | 
			
		||||
                "patience": data["patience"],
 | 
			
		||||
                "device": data["device"],
 | 
			
		||||
                "cache": data["cache"],
 | 
			
		||||
                "data": data["data"]}
 | 
			
		||||
                'batch': data['batch_size'],
 | 
			
		||||
                'epochs': data['epochs'],
 | 
			
		||||
                'imgsz': data['imgsz'],
 | 
			
		||||
                'patience': data['patience'],
 | 
			
		||||
                'device': data['device'],
 | 
			
		||||
                'cache': data['cache'],
 | 
			
		||||
                'data': data['data']}
 | 
			
		||||
 | 
			
		||||
            self.input_file = data.get("cfg", data["weights"])
 | 
			
		||||
            self.input_file = data.get('cfg', data['weights'])
 | 
			
		||||
 | 
			
		||||
            # hack for yolov5 cfg adds u
 | 
			
		||||
            if "cfg" in data and "yolov5" in data["cfg"]:
 | 
			
		||||
                self.input_file = data["cfg"].replace(".yaml", "u.yaml")
 | 
			
		||||
            if 'cfg' in data and 'yolov5' in data['cfg']:
 | 
			
		||||
                self.input_file = data['cfg'].replace('.yaml', 'u.yaml')
 | 
			
		||||
 | 
			
		||||
            return data
 | 
			
		||||
        except requests.exceptions.ConnectionError as e:
 | 
			
		||||
            raise ConnectionRefusedError("ERROR: The HUB server is not online. Please try again later.") from e
 | 
			
		||||
            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
 | 
			
		||||
        except Exception:
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def check_disk_space(self):
 | 
			
		||||
        if not check_dataset_disk_space(self.model["data"]):
 | 
			
		||||
            raise MemoryError("Not enough disk space")
 | 
			
		||||
        if not check_dataset_disk_space(self.model['data']):
 | 
			
		||||
            raise MemoryError('Not enough disk space')
 | 
			
		||||
 | 
			
		||||
    def register_callbacks(self, trainer):
 | 
			
		||||
        trainer.add_callback("on_pretrain_routine_end", self.on_pretrain_routine_end)
 | 
			
		||||
        trainer.add_callback("on_fit_epoch_end", self.on_fit_epoch_end)
 | 
			
		||||
        trainer.add_callback("on_model_save", self.on_model_save)
 | 
			
		||||
        trainer.add_callback("on_train_end", self.on_train_end)
 | 
			
		||||
        trainer.add_callback('on_pretrain_routine_end', self.on_pretrain_routine_end)
 | 
			
		||||
        trainer.add_callback('on_fit_epoch_end', self.on_fit_epoch_end)
 | 
			
		||||
        trainer.add_callback('on_model_save', self.on_model_save)
 | 
			
		||||
        trainer.add_callback('on_train_end', self.on_train_end)
 | 
			
		||||
 | 
			
		||||
    def on_pretrain_routine_end(self, trainer):
 | 
			
		||||
        """
 | 
			
		||||
@ -140,57 +140,57 @@ class HubTrainingSession:
 | 
			
		||||
        This method does not use trainer. It is passed to all callbacks by default.
 | 
			
		||||
        """
 | 
			
		||||
        # Start timer for upload rate limit
 | 
			
		||||
        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
 | 
			
		||||
        self._timers = {"metrics": time(), "ckpt": time()}  # start timer on self.rate_limit
 | 
			
		||||
        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
 | 
			
		||||
        self._timers = {'metrics': time(), 'ckpt': time()}  # start timer on self.rate_limit
 | 
			
		||||
 | 
			
		||||
    def on_fit_epoch_end(self, trainer):
 | 
			
		||||
        # Upload metrics after val end
 | 
			
		||||
        all_plots = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
 | 
			
		||||
        all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
 | 
			
		||||
 | 
			
		||||
        if trainer.epoch == 0:
 | 
			
		||||
            model_info = {
 | 
			
		||||
                "model/parameters": get_num_params(trainer.model),
 | 
			
		||||
                "model/GFLOPs": round(get_flops(trainer.model), 3),
 | 
			
		||||
                "model/speed(ms)": round(trainer.validator.speed[1], 3)}
 | 
			
		||||
                'model/parameters': get_num_params(trainer.model),
 | 
			
		||||
                'model/GFLOPs': round(get_flops(trainer.model), 3),
 | 
			
		||||
                'model/speed(ms)': round(trainer.validator.speed[1], 3)}
 | 
			
		||||
            all_plots = {**all_plots, **model_info}
 | 
			
		||||
        self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
 | 
			
		||||
        if time() - self._timers["metrics"] > self._rate_limits["metrics"]:
 | 
			
		||||
        if time() - self._timers['metrics'] > self._rate_limits['metrics']:
 | 
			
		||||
            self.upload_metrics()
 | 
			
		||||
            self._timers["metrics"] = time()  # reset timer
 | 
			
		||||
            self._timers['metrics'] = time()  # reset timer
 | 
			
		||||
            self._metrics_queue = {}  # reset queue
 | 
			
		||||
 | 
			
		||||
    def on_model_save(self, trainer):
 | 
			
		||||
        # Upload checkpoints with rate limiting
 | 
			
		||||
        is_best = trainer.best_fitness == trainer.fitness
 | 
			
		||||
        if time() - self._timers["ckpt"] > self._rate_limits["ckpt"]:
 | 
			
		||||
            LOGGER.info(f"{PREFIX}Uploading checkpoint {self.model_id}")
 | 
			
		||||
        if time() - self._timers['ckpt'] > self._rate_limits['ckpt']:
 | 
			
		||||
            LOGGER.info(f'{PREFIX}Uploading checkpoint {self.model_id}')
 | 
			
		||||
            self._upload_model(trainer.epoch, trainer.last, is_best)
 | 
			
		||||
            self._timers["ckpt"] = time()  # reset timer
 | 
			
		||||
            self._timers['ckpt'] = time()  # reset timer
 | 
			
		||||
 | 
			
		||||
    def on_train_end(self, trainer):
 | 
			
		||||
        # Upload final model and metrics with exponential standoff
 | 
			
		||||
        LOGGER.info(f"{PREFIX}Training completed successfully ✅")
 | 
			
		||||
        LOGGER.info(f"{PREFIX}Uploading final {self.model_id}")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}Training completed successfully ✅')
 | 
			
		||||
        LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
 | 
			
		||||
 | 
			
		||||
        # hack for fetching mAP
 | 
			
		||||
        mAP = trainer.metrics.get("metrics/mAP50-95(B)", 0)
 | 
			
		||||
        mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
 | 
			
		||||
        self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True)  # results[3] is mAP0.5:0.95
 | 
			
		||||
        self.alive = False  # stop heartbeats
 | 
			
		||||
        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
 | 
			
		||||
 | 
			
		||||
    def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
 | 
			
		||||
        # Upload a model to HUB
 | 
			
		||||
        file = None
 | 
			
		||||
        if Path(weights).is_file():
 | 
			
		||||
            with open(weights, "rb") as f:
 | 
			
		||||
            with open(weights, 'rb') as f:
 | 
			
		||||
                file = f.read()
 | 
			
		||||
        file_param = {"best.pt" if final else "last.pt": file}
 | 
			
		||||
        endpoint = f"{self.api_url}/upload"
 | 
			
		||||
        data = {"epoch": epoch}
 | 
			
		||||
        file_param = {'best.pt' if final else 'last.pt': file}
 | 
			
		||||
        endpoint = f'{self.api_url}/upload'
 | 
			
		||||
        data = {'epoch': epoch}
 | 
			
		||||
        if final:
 | 
			
		||||
            data.update({"type": "final", "map": map})
 | 
			
		||||
            data.update({'type': 'final', 'map': map})
 | 
			
		||||
        else:
 | 
			
		||||
            data.update({"type": "epoch", "isBest": bool(is_best)})
 | 
			
		||||
            data.update({'type': 'epoch', 'isBest': bool(is_best)})
 | 
			
		||||
 | 
			
		||||
        smart_request(
 | 
			
		||||
            endpoint,
 | 
			
		||||
@ -207,14 +207,14 @@ class HubTrainingSession:
 | 
			
		||||
        self.alive = True
 | 
			
		||||
        while self.alive:
 | 
			
		||||
            r = smart_request(
 | 
			
		||||
                f"{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}",
 | 
			
		||||
                f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
 | 
			
		||||
                json={
 | 
			
		||||
                    "agent": AGENT_NAME,
 | 
			
		||||
                    "agentId": self.agent_id},
 | 
			
		||||
                    'agent': AGENT_NAME,
 | 
			
		||||
                    'agentId': self.agent_id},
 | 
			
		||||
                headers=self.auth_header,
 | 
			
		||||
                retry=0,
 | 
			
		||||
                code=5,
 | 
			
		||||
                thread=False,
 | 
			
		||||
            )
 | 
			
		||||
            self.agent_id = r.json().get("data", {}).get("agentId", None)
 | 
			
		||||
            sleep(self._rate_limits["heartbeat"])
 | 
			
		||||
            self.agent_id = r.json().get('data', {}).get('agentId', None)
 | 
			
		||||
            sleep(self._rate_limits['heartbeat'])
 | 
			
		||||
 | 
			
		||||
@ -18,14 +18,14 @@ from ultralytics.yolo.utils.checks import check_online
 | 
			
		||||
 | 
			
		||||
PREFIX = colorstr('Ultralytics: ')
 | 
			
		||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
 | 
			
		||||
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
 | 
			
		||||
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
 | 
			
		||||
    # Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
 | 
			
		||||
    gib = 1 << 30  # bytes per GiB
 | 
			
		||||
    data = int(requests.head(url).headers['Content-Length']) / gib  # dataset size (GB)
 | 
			
		||||
    total, used, free = (x / gib for x in shutil.disk_usage("/"))  # bytes
 | 
			
		||||
    total, used, free = (x / gib for x in shutil.disk_usage('/'))  # bytes
 | 
			
		||||
    LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
 | 
			
		||||
    if data * sf < free:
 | 
			
		||||
        return True  # sufficient space
 | 
			
		||||
@ -57,7 +57,7 @@ def request_with_credentials(url: str) -> any:
 | 
			
		||||
                });
 | 
			
		||||
            });
 | 
			
		||||
            """ % url))
 | 
			
		||||
    return output.eval_js("_hub_tmp")
 | 
			
		||||
    return output.eval_js('_hub_tmp')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Deprecated TODO: eliminate this function?
 | 
			
		||||
@ -84,7 +84,7 @@ def split_key(key=''):
 | 
			
		||||
    return api_key, model_id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post", verbose=True, **kwargs):
 | 
			
		||||
def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post', verbose=True, **kwargs):
 | 
			
		||||
    """
 | 
			
		||||
    Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
 | 
			
		||||
 | 
			
		||||
@ -128,7 +128,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
 | 
			
		||||
                    m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
 | 
			
		||||
                        f"Please retry after {h['Retry-After']}s."
 | 
			
		||||
                if verbose:
 | 
			
		||||
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
 | 
			
		||||
                    LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
 | 
			
		||||
                if r.status_code not in retry_codes:
 | 
			
		||||
                    return r
 | 
			
		||||
            time.sleep(2 ** i)  # exponential standoff
 | 
			
		||||
@ -149,17 +149,17 @@ class Traces:
 | 
			
		||||
        self.rate_limit = 3.0  # rate limit (seconds)
 | 
			
		||||
        self.t = 0.0  # rate limit timer (seconds)
 | 
			
		||||
        self.metadata = {
 | 
			
		||||
            "sys_argv_name": Path(sys.argv[0]).name,
 | 
			
		||||
            "install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
 | 
			
		||||
            "python": platform.python_version(),
 | 
			
		||||
            "release": __version__,
 | 
			
		||||
            "environment": ENVIRONMENT}
 | 
			
		||||
            'sys_argv_name': Path(sys.argv[0]).name,
 | 
			
		||||
            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
 | 
			
		||||
            'python': platform.python_version(),
 | 
			
		||||
            'release': __version__,
 | 
			
		||||
            'environment': ENVIRONMENT}
 | 
			
		||||
        self.enabled = SETTINGS['sync'] and \
 | 
			
		||||
                       RANK in {-1, 0} and \
 | 
			
		||||
                       check_online() and \
 | 
			
		||||
                       not is_pytest_running() and \
 | 
			
		||||
                       not is_github_actions_ci() and \
 | 
			
		||||
                       (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
 | 
			
		||||
                       (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
 | 
			
		||||
 | 
			
		||||
    def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -41,4 +41,4 @@ head:
 | 
			
		||||
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 | 
			
		||||
 | 
			
		||||
   [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
 | 
			
		||||
  ]
 | 
			
		||||
  ]
 | 
			
		||||
 | 
			
		||||
@ -41,4 +41,4 @@ head:
 | 
			
		||||
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 | 
			
		||||
 | 
			
		||||
   [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
 | 
			
		||||
  ]
 | 
			
		||||
  ]
 | 
			
		||||
 | 
			
		||||
@ -41,4 +41,4 @@ head:
 | 
			
		||||
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 | 
			
		||||
 | 
			
		||||
   [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
 | 
			
		||||
  ]
 | 
			
		||||
  ]
 | 
			
		||||
 | 
			
		||||
@ -42,4 +42,4 @@ head:
 | 
			
		||||
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 | 
			
		||||
 | 
			
		||||
   [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
 | 
			
		||||
  ]
 | 
			
		||||
  ]
 | 
			
		||||
 | 
			
		||||
@ -41,4 +41,4 @@ head:
 | 
			
		||||
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 | 
			
		||||
 | 
			
		||||
   [[17, 20, 23], 1, Detect, [nc]],  # Detect(P3, P4, P5)
 | 
			
		||||
  ]
 | 
			
		||||
  ]
 | 
			
		||||
 | 
			
		||||
@ -127,11 +127,11 @@ class AutoBackend(nn.Module):
 | 
			
		||||
                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
 | 
			
		||||
            network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
 | 
			
		||||
            if network.get_parameters()[0].get_layout().empty:
 | 
			
		||||
                network.get_parameters()[0].set_layout(Layout("NCHW"))
 | 
			
		||||
                network.get_parameters()[0].set_layout(Layout('NCHW'))
 | 
			
		||||
            batch_dim = get_batch(network)
 | 
			
		||||
            if batch_dim.is_static:
 | 
			
		||||
                batch_size = batch_dim.get_length()
 | 
			
		||||
            executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
 | 
			
		||||
            executable_network = ie.compile_model(network, device_name='CPU')  # device_name="MYRIAD" for Intel NCS2
 | 
			
		||||
        elif engine:  # TensorRT
 | 
			
		||||
            LOGGER.info(f'Loading {w} for TensorRT inference...')
 | 
			
		||||
            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
 | 
			
		||||
@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
 | 
			
		||||
            import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
            def wrap_frozen_graph(gd, inputs, outputs):
 | 
			
		||||
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
 | 
			
		||||
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
 | 
			
		||||
                ge = x.graph.as_graph_element
 | 
			
		||||
                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
 | 
			
		||||
 | 
			
		||||
@ -198,7 +198,7 @@ class AutoBackend(nn.Module):
 | 
			
		||||
            gd = tf.Graph().as_graph_def()  # TF GraphDef
 | 
			
		||||
            with open(w, 'rb') as f:
 | 
			
		||||
                gd.ParseFromString(f.read())
 | 
			
		||||
            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
 | 
			
		||||
            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
 | 
			
		||||
        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
 | 
			
		||||
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
 | 
			
		||||
                from tflite_runtime.interpreter import Interpreter, load_delegate
 | 
			
		||||
@ -220,9 +220,9 @@ class AutoBackend(nn.Module):
 | 
			
		||||
            output_details = interpreter.get_output_details()  # outputs
 | 
			
		||||
            # load metadata
 | 
			
		||||
            with contextlib.suppress(zipfile.BadZipFile):
 | 
			
		||||
                with zipfile.ZipFile(w, "r") as model:
 | 
			
		||||
                with zipfile.ZipFile(w, 'r') as model:
 | 
			
		||||
                    meta_file = model.namelist()[0]
 | 
			
		||||
                    meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
 | 
			
		||||
                    meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
 | 
			
		||||
                    stride, names = int(meta['stride']), meta['names']
 | 
			
		||||
        elif tfjs:  # TF.js
 | 
			
		||||
            raise NotImplementedError('YOLOv8 TF.js inference is not supported')
 | 
			
		||||
@ -251,8 +251,8 @@ class AutoBackend(nn.Module):
 | 
			
		||||
        else:
 | 
			
		||||
            from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
 | 
			
		||||
            raise TypeError(f"model='{w}' is not a supported model format. "
 | 
			
		||||
                            "See https://docs.ultralytics.com/tasks/detection/#export for help."
 | 
			
		||||
                            f"\n\n{EXPORT_FORMATS_TABLE}")
 | 
			
		||||
                            'See https://docs.ultralytics.com/tasks/detection/#export for help.'
 | 
			
		||||
                            f'\n\n{EXPORT_FORMATS_TABLE}')
 | 
			
		||||
 | 
			
		||||
        # Load external metadata YAML
 | 
			
		||||
        if xml or saved_model or paddle:
 | 
			
		||||
@ -410,5 +410,5 @@ class AutoBackend(nn.Module):
 | 
			
		||||
        url = urlparse(p)  # if url may be Triton inference server
 | 
			
		||||
        types = [s in Path(p).name for s in sf]
 | 
			
		||||
        types[8] &= not types[9]  # tflite &= not edgetpu
 | 
			
		||||
        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
 | 
			
		||||
        triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
 | 
			
		||||
        return types + [triton]
 | 
			
		||||
 | 
			
		||||
@ -99,7 +99,7 @@ class AutoShape(nn.Module):
 | 
			
		||||
                shape1.append([y * g for y in s])
 | 
			
		||||
                ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
 | 
			
		||||
            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
 | 
			
		||||
            x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims]  # pad
 | 
			
		||||
            x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims]  # pad
 | 
			
		||||
            x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
 | 
			
		||||
            x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -160,7 +160,7 @@ class BaseModel(nn.Module):
 | 
			
		||||
            weights (str): The weights to load into the model.
 | 
			
		||||
        """
 | 
			
		||||
        # Force all tasks to implement this function
 | 
			
		||||
        raise NotImplementedError("This function needs to be implemented by derived classes!")
 | 
			
		||||
        raise NotImplementedError('This function needs to be implemented by derived classes!')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DetectionModel(BaseModel):
 | 
			
		||||
@ -249,7 +249,7 @@ class SegmentationModel(DetectionModel):
 | 
			
		||||
        super().__init__(cfg, ch, nc, verbose)
 | 
			
		||||
 | 
			
		||||
    def _forward_augment(self, x):
 | 
			
		||||
        raise NotImplementedError("WARNING ⚠️ SegmentationModel has not supported augment inference yet!")
 | 
			
		||||
        raise NotImplementedError('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClassificationModel(BaseModel):
 | 
			
		||||
@ -292,7 +292,7 @@ class ClassificationModel(BaseModel):
 | 
			
		||||
        self.info()
 | 
			
		||||
 | 
			
		||||
    def load(self, weights):
 | 
			
		||||
        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
 | 
			
		||||
        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
 | 
			
		||||
        csd = model.float().state_dict()
 | 
			
		||||
        csd = intersect_dicts(csd, self.state_dict())  # intersect
 | 
			
		||||
        self.load_state_dict(csd, strict=False)  # load
 | 
			
		||||
@ -341,10 +341,10 @@ def torch_safe_load(weight):
 | 
			
		||||
        return torch.load(file, map_location='cpu')  # load
 | 
			
		||||
    except ModuleNotFoundError as e:
 | 
			
		||||
        if e.name == 'omegaconf':  # e.name is missing module name
 | 
			
		||||
            LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
 | 
			
		||||
                           f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
 | 
			
		||||
                           f"\nRecommend fixes are to train a new model using updated ultralytics package or to "
 | 
			
		||||
                           f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
 | 
			
		||||
            LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
 | 
			
		||||
                           f'\nAutoInstall will run now for {e.name} but this feature will be removed in the future.'
 | 
			
		||||
                           f'\nRecommend fixes are to train a new model using updated ultralytics package or to '
 | 
			
		||||
                           f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
 | 
			
		||||
        if e.name != 'models':
 | 
			
		||||
            check_requirements(e.name)  # install missing module
 | 
			
		||||
        return torch.load(file, map_location='cpu')  # load
 | 
			
		||||
@ -489,13 +489,13 @@ def guess_model_task(model):
 | 
			
		||||
 | 
			
		||||
    def cfg2task(cfg):
 | 
			
		||||
        # Guess from YAML dictionary
 | 
			
		||||
        m = cfg["head"][-1][-2].lower()  # output module name
 | 
			
		||||
        if m in ["classify", "classifier", "cls", "fc"]:
 | 
			
		||||
            return "classify"
 | 
			
		||||
        if m in ["detect"]:
 | 
			
		||||
            return "detect"
 | 
			
		||||
        if m in ["segment"]:
 | 
			
		||||
            return "segment"
 | 
			
		||||
        m = cfg['head'][-1][-2].lower()  # output module name
 | 
			
		||||
        if m in ['classify', 'classifier', 'cls', 'fc']:
 | 
			
		||||
            return 'classify'
 | 
			
		||||
        if m in ['detect']:
 | 
			
		||||
            return 'detect'
 | 
			
		||||
        if m in ['segment']:
 | 
			
		||||
            return 'segment'
 | 
			
		||||
 | 
			
		||||
    # Guess from model cfg
 | 
			
		||||
    if isinstance(model, dict):
 | 
			
		||||
@ -513,22 +513,22 @@ def guess_model_task(model):
 | 
			
		||||
 | 
			
		||||
        for m in model.modules():
 | 
			
		||||
            if isinstance(m, Detect):
 | 
			
		||||
                return "detect"
 | 
			
		||||
                return 'detect'
 | 
			
		||||
            elif isinstance(m, Segment):
 | 
			
		||||
                return "segment"
 | 
			
		||||
                return 'segment'
 | 
			
		||||
            elif isinstance(m, Classify):
 | 
			
		||||
                return "classify"
 | 
			
		||||
                return 'classify'
 | 
			
		||||
 | 
			
		||||
    # Guess from model filename
 | 
			
		||||
    if isinstance(model, (str, Path)):
 | 
			
		||||
        model = Path(model).stem
 | 
			
		||||
        if '-seg' in model:
 | 
			
		||||
            return "segment"
 | 
			
		||||
            return 'segment'
 | 
			
		||||
        elif '-cls' in model:
 | 
			
		||||
            return "classify"
 | 
			
		||||
            return 'classify'
 | 
			
		||||
        else:
 | 
			
		||||
            return "detect"
 | 
			
		||||
            return 'detect'
 | 
			
		||||
 | 
			
		||||
    # Unable to determine task from model
 | 
			
		||||
    raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
 | 
			
		||||
    raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
 | 
			
		||||
                      "i.e. 'task=detect', 'task=segment' or 'task=classify'.")
 | 
			
		||||
 | 
			
		||||
@ -4,14 +4,14 @@ from ultralytics.tracker import BOTSORT, BYTETracker
 | 
			
		||||
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
 | 
			
		||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
 | 
			
		||||
 | 
			
		||||
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
 | 
			
		||||
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
 | 
			
		||||
check_requirements('lap')  # for linear_assignment
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_predict_start(predictor):
 | 
			
		||||
    tracker = check_yaml(predictor.args.tracker)
 | 
			
		||||
    cfg = IterableSimpleNamespace(**yaml_load(tracker))
 | 
			
		||||
    assert cfg.tracker_type in ["bytetrack", "botsort"], \
 | 
			
		||||
    assert cfg.tracker_type in ['bytetrack', 'botsort'], \
 | 
			
		||||
            f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
 | 
			
		||||
    trackers = []
 | 
			
		||||
    for _ in range(predictor.dataset.bs):
 | 
			
		||||
@ -38,5 +38,5 @@ def on_predict_postprocess_end(predictor):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_tracker(model):
 | 
			
		||||
    model.add_callback("on_predict_start", on_predict_start)
 | 
			
		||||
    model.add_callback("on_predict_postprocess_end", on_predict_postprocess_end)
 | 
			
		||||
    model.add_callback('on_predict_start', on_predict_start)
 | 
			
		||||
    model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
 | 
			
		||||
 | 
			
		||||
@ -153,7 +153,7 @@ class STrack(BaseTrack):
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
 | 
			
		||||
        return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BYTETracker:
 | 
			
		||||
@ -206,7 +206,7 @@ class BYTETracker:
 | 
			
		||||
        strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
 | 
			
		||||
        # Predict the current location with KF
 | 
			
		||||
        self.multi_predict(strack_pool)
 | 
			
		||||
        if hasattr(self, "gmc"):
 | 
			
		||||
        if hasattr(self, 'gmc'):
 | 
			
		||||
            warp = self.gmc.apply(img, dets)
 | 
			
		||||
            STrack.multi_gmc(strack_pool, warp)
 | 
			
		||||
            STrack.multi_gmc(unconfirmed, warp)
 | 
			
		||||
 | 
			
		||||
@ -50,14 +50,14 @@ class GMC:
 | 
			
		||||
                seqName = seqName[:-6]
 | 
			
		||||
            elif '-DPM' in seqName or '-SDP' in seqName:
 | 
			
		||||
                seqName = seqName[:-4]
 | 
			
		||||
            self.gmcFile = open(f"{filePath}/GMC-{seqName}.txt")
 | 
			
		||||
            self.gmcFile = open(f'{filePath}/GMC-{seqName}.txt')
 | 
			
		||||
 | 
			
		||||
            if self.gmcFile is None:
 | 
			
		||||
                raise ValueError(f"Error: Unable to open GMC file in directory:{filePath}")
 | 
			
		||||
                raise ValueError(f'Error: Unable to open GMC file in directory:{filePath}')
 | 
			
		||||
        elif self.method in ['none', 'None']:
 | 
			
		||||
            self.method = 'none'
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f"Error: Unknown CMC method:{method}")
 | 
			
		||||
            raise ValueError(f'Error: Unknown CMC method:{method}')
 | 
			
		||||
 | 
			
		||||
        self.prevFrame = None
 | 
			
		||||
        self.prevKeyPoints = None
 | 
			
		||||
@ -302,7 +302,7 @@ class GMC:
 | 
			
		||||
 | 
			
		||||
    def applyFile(self, raw_frame, detections=None):
 | 
			
		||||
        line = self.gmcFile.readline()
 | 
			
		||||
        tokens = line.split("\t")
 | 
			
		||||
        tokens = line.split('\t')
 | 
			
		||||
        H = np.eye(2, 3, dtype=np.float_)
 | 
			
		||||
        H[0, 0] = float(tokens[1])
 | 
			
		||||
        H[0, 1] = float(tokens[2])
 | 
			
		||||
 | 
			
		||||
@ -2,4 +2,4 @@
 | 
			
		||||
 | 
			
		||||
from . import v8
 | 
			
		||||
 | 
			
		||||
__all__ = ["v8"]
 | 
			
		||||
__all__ = ['v8']
 | 
			
		||||
 | 
			
		||||
@ -142,8 +142,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
 | 
			
		||||
        string = ''
 | 
			
		||||
        for x in mismatched:
 | 
			
		||||
            matches = get_close_matches(x, base)  # key list
 | 
			
		||||
            matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
 | 
			
		||||
            match_str = f"Similar arguments are i.e. {matches}." if matches else ''
 | 
			
		||||
            matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
 | 
			
		||||
            match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
 | 
			
		||||
            string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
 | 
			
		||||
        raise SyntaxError(string + CLI_HELP_MSG) from e
 | 
			
		||||
 | 
			
		||||
@ -163,10 +163,10 @@ def merge_equals_args(args: List[str]) -> List[str]:
 | 
			
		||||
    new_args = []
 | 
			
		||||
    for i, arg in enumerate(args):
 | 
			
		||||
        if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
 | 
			
		||||
            new_args[-1] += f"={args[i + 1]}"
 | 
			
		||||
            new_args[-1] += f'={args[i + 1]}'
 | 
			
		||||
            del args[i + 1]
 | 
			
		||||
        elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val']
 | 
			
		||||
            new_args.append(f"{arg}{args[i + 1]}")
 | 
			
		||||
            new_args.append(f'{arg}{args[i + 1]}')
 | 
			
		||||
            del args[i + 1]
 | 
			
		||||
        elif arg.startswith('=') and i > 0:  # merge ['arg', '=val']
 | 
			
		||||
            new_args[-1] += arg
 | 
			
		||||
@ -223,7 +223,7 @@ def entrypoint(debug=''):
 | 
			
		||||
                k, v = a.split('=', 1)  # split on first '=' sign
 | 
			
		||||
                assert v, f"missing '{k}' value"
 | 
			
		||||
                if k == 'cfg':  # custom.yaml passed
 | 
			
		||||
                    LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
 | 
			
		||||
                    LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
 | 
			
		||||
                    overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
 | 
			
		||||
                else:
 | 
			
		||||
                    if v.lower() == 'none':
 | 
			
		||||
@ -237,7 +237,7 @@ def entrypoint(debug=''):
 | 
			
		||||
                            v = eval(v)
 | 
			
		||||
                    overrides[k] = v
 | 
			
		||||
            except (NameError, SyntaxError, ValueError, AssertionError) as e:
 | 
			
		||||
                check_cfg_mismatch(full_args_dict, {a: ""}, e)
 | 
			
		||||
                check_cfg_mismatch(full_args_dict, {a: ''}, e)
 | 
			
		||||
 | 
			
		||||
        elif a in tasks:
 | 
			
		||||
            overrides['task'] = a
 | 
			
		||||
@ -252,7 +252,7 @@ def entrypoint(debug=''):
 | 
			
		||||
            raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
 | 
			
		||||
                              f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
 | 
			
		||||
        else:
 | 
			
		||||
            check_cfg_mismatch(full_args_dict, {a: ""})
 | 
			
		||||
            check_cfg_mismatch(full_args_dict, {a: ''})
 | 
			
		||||
 | 
			
		||||
    # Defaults
 | 
			
		||||
    task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
 | 
			
		||||
@ -287,8 +287,8 @@ def entrypoint(debug=''):
 | 
			
		||||
    task = model.task
 | 
			
		||||
    overrides['task'] = task
 | 
			
		||||
    if mode in {'predict', 'track'} and 'source' not in overrides:
 | 
			
		||||
        overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
 | 
			
		||||
            else "https://ultralytics.com/images/bus.jpg"
 | 
			
		||||
        overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
 | 
			
		||||
            else 'https://ultralytics.com/images/bus.jpg'
 | 
			
		||||
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
 | 
			
		||||
    elif mode in ('train', 'val'):
 | 
			
		||||
        if 'data' not in overrides:
 | 
			
		||||
@ -308,7 +308,7 @@ def entrypoint(debug=''):
 | 
			
		||||
def copy_default_cfg():
 | 
			
		||||
    new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
 | 
			
		||||
    shutil.copy2(DEFAULT_CFG_PATH, new_file)
 | 
			
		||||
    LOGGER.info(f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
 | 
			
		||||
    LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
 | 
			
		||||
                f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,11 +6,11 @@ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
 | 
			
		||||
from .dataset_wrappers import MixAndRectDataset
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "BaseDataset",
 | 
			
		||||
    "ClassificationDataset",
 | 
			
		||||
    "MixAndRectDataset",
 | 
			
		||||
    "SemanticDataset",
 | 
			
		||||
    "YOLODataset",
 | 
			
		||||
    "build_classification_dataloader",
 | 
			
		||||
    "build_dataloader",
 | 
			
		||||
    "load_inference_source",]
 | 
			
		||||
    'BaseDataset',
 | 
			
		||||
    'ClassificationDataset',
 | 
			
		||||
    'MixAndRectDataset',
 | 
			
		||||
    'SemanticDataset',
 | 
			
		||||
    'YOLODataset',
 | 
			
		||||
    'build_classification_dataloader',
 | 
			
		||||
    'build_dataloader',
 | 
			
		||||
    'load_inference_source',]
 | 
			
		||||
 | 
			
		||||
@ -55,11 +55,11 @@ class Compose:
 | 
			
		||||
        return self.transforms
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        format_string = f"{self.__class__.__name__}("
 | 
			
		||||
        format_string = f'{self.__class__.__name__}('
 | 
			
		||||
        for t in self.transforms:
 | 
			
		||||
            format_string += "\n"
 | 
			
		||||
            format_string += f"    {t}"
 | 
			
		||||
        format_string += "\n)"
 | 
			
		||||
            format_string += '\n'
 | 
			
		||||
            format_string += f'    {t}'
 | 
			
		||||
        format_string += '\n)'
 | 
			
		||||
        return format_string
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -86,11 +86,11 @@ class BaseMixTransform:
 | 
			
		||||
        if self.pre_transform is not None:
 | 
			
		||||
            for i, data in enumerate(mix_labels):
 | 
			
		||||
                mix_labels[i] = self.pre_transform(data)
 | 
			
		||||
        labels["mix_labels"] = mix_labels
 | 
			
		||||
        labels['mix_labels'] = mix_labels
 | 
			
		||||
 | 
			
		||||
        # Mosaic or MixUp
 | 
			
		||||
        labels = self._mix_transform(labels)
 | 
			
		||||
        labels.pop("mix_labels", None)
 | 
			
		||||
        labels.pop('mix_labels', None)
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
    def _mix_transform(self, labels):
 | 
			
		||||
@ -109,7 +109,7 @@ class Mosaic(BaseMixTransform):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
 | 
			
		||||
        assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
 | 
			
		||||
        assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
 | 
			
		||||
        super().__init__(dataset=dataset, p=p)
 | 
			
		||||
        self.dataset = dataset
 | 
			
		||||
        self.imgsz = imgsz
 | 
			
		||||
@ -120,15 +120,15 @@ class Mosaic(BaseMixTransform):
 | 
			
		||||
 | 
			
		||||
    def _mix_transform(self, labels):
 | 
			
		||||
        mosaic_labels = []
 | 
			
		||||
        assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
 | 
			
		||||
        assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
 | 
			
		||||
        assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
 | 
			
		||||
        assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
 | 
			
		||||
        s = self.imgsz
 | 
			
		||||
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
 | 
			
		||||
        for i in range(4):
 | 
			
		||||
            labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
 | 
			
		||||
            labels_patch = (labels if i == 0 else labels['mix_labels'][i - 1]).copy()
 | 
			
		||||
            # Load image
 | 
			
		||||
            img = labels_patch["img"]
 | 
			
		||||
            h, w = labels_patch.pop("resized_shape")
 | 
			
		||||
            img = labels_patch['img']
 | 
			
		||||
            h, w = labels_patch.pop('resized_shape')
 | 
			
		||||
 | 
			
		||||
            # place img in img4
 | 
			
		||||
            if i == 0:  # top left
 | 
			
		||||
@ -152,15 +152,15 @@ class Mosaic(BaseMixTransform):
 | 
			
		||||
            labels_patch = self._update_labels(labels_patch, padw, padh)
 | 
			
		||||
            mosaic_labels.append(labels_patch)
 | 
			
		||||
        final_labels = self._cat_labels(mosaic_labels)
 | 
			
		||||
        final_labels["img"] = img4
 | 
			
		||||
        final_labels['img'] = img4
 | 
			
		||||
        return final_labels
 | 
			
		||||
 | 
			
		||||
    def _update_labels(self, labels, padw, padh):
 | 
			
		||||
        """Update labels"""
 | 
			
		||||
        nh, nw = labels["img"].shape[:2]
 | 
			
		||||
        labels["instances"].convert_bbox(format="xyxy")
 | 
			
		||||
        labels["instances"].denormalize(nw, nh)
 | 
			
		||||
        labels["instances"].add_padding(padw, padh)
 | 
			
		||||
        nh, nw = labels['img'].shape[:2]
 | 
			
		||||
        labels['instances'].convert_bbox(format='xyxy')
 | 
			
		||||
        labels['instances'].denormalize(nw, nh)
 | 
			
		||||
        labels['instances'].add_padding(padw, padh)
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
    def _cat_labels(self, mosaic_labels):
 | 
			
		||||
@ -169,16 +169,16 @@ class Mosaic(BaseMixTransform):
 | 
			
		||||
        cls = []
 | 
			
		||||
        instances = []
 | 
			
		||||
        for labels in mosaic_labels:
 | 
			
		||||
            cls.append(labels["cls"])
 | 
			
		||||
            instances.append(labels["instances"])
 | 
			
		||||
            cls.append(labels['cls'])
 | 
			
		||||
            instances.append(labels['instances'])
 | 
			
		||||
        final_labels = {
 | 
			
		||||
            "im_file": mosaic_labels[0]["im_file"],
 | 
			
		||||
            "ori_shape": mosaic_labels[0]["ori_shape"],
 | 
			
		||||
            "resized_shape": (self.imgsz * 2, self.imgsz * 2),
 | 
			
		||||
            "cls": np.concatenate(cls, 0),
 | 
			
		||||
            "instances": Instances.concatenate(instances, axis=0),
 | 
			
		||||
            "mosaic_border": self.border}
 | 
			
		||||
        final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
 | 
			
		||||
            'im_file': mosaic_labels[0]['im_file'],
 | 
			
		||||
            'ori_shape': mosaic_labels[0]['ori_shape'],
 | 
			
		||||
            'resized_shape': (self.imgsz * 2, self.imgsz * 2),
 | 
			
		||||
            'cls': np.concatenate(cls, 0),
 | 
			
		||||
            'instances': Instances.concatenate(instances, axis=0),
 | 
			
		||||
            'mosaic_border': self.border}
 | 
			
		||||
        final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
 | 
			
		||||
        return final_labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -193,10 +193,10 @@ class MixUp(BaseMixTransform):
 | 
			
		||||
    def _mix_transform(self, labels):
 | 
			
		||||
        # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
 | 
			
		||||
        r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
 | 
			
		||||
        labels2 = labels["mix_labels"][0]
 | 
			
		||||
        labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
 | 
			
		||||
        labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
 | 
			
		||||
        labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
 | 
			
		||||
        labels2 = labels['mix_labels'][0]
 | 
			
		||||
        labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
 | 
			
		||||
        labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
 | 
			
		||||
        labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -338,18 +338,18 @@ class RandomPerspective:
 | 
			
		||||
        Args:
 | 
			
		||||
            labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
 | 
			
		||||
        """
 | 
			
		||||
        if self.pre_transform and "mosaic_border" not in labels:
 | 
			
		||||
        if self.pre_transform and 'mosaic_border' not in labels:
 | 
			
		||||
            labels = self.pre_transform(labels)
 | 
			
		||||
            labels.pop("ratio_pad")  # do not need ratio pad
 | 
			
		||||
            labels.pop('ratio_pad')  # do not need ratio pad
 | 
			
		||||
 | 
			
		||||
        img = labels["img"]
 | 
			
		||||
        cls = labels["cls"]
 | 
			
		||||
        instances = labels.pop("instances")
 | 
			
		||||
        img = labels['img']
 | 
			
		||||
        cls = labels['cls']
 | 
			
		||||
        instances = labels.pop('instances')
 | 
			
		||||
        # make sure the coord formats are right
 | 
			
		||||
        instances.convert_bbox(format="xyxy")
 | 
			
		||||
        instances.convert_bbox(format='xyxy')
 | 
			
		||||
        instances.denormalize(*img.shape[:2][::-1])
 | 
			
		||||
 | 
			
		||||
        border = labels.pop("mosaic_border", self.border)
 | 
			
		||||
        border = labels.pop('mosaic_border', self.border)
 | 
			
		||||
        self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2  # w, h
 | 
			
		||||
        # M is affine matrix
 | 
			
		||||
        # scale for func:`box_candidates`
 | 
			
		||||
@ -365,7 +365,7 @@ class RandomPerspective:
 | 
			
		||||
 | 
			
		||||
        if keypoints is not None:
 | 
			
		||||
            keypoints = self.apply_keypoints(keypoints, M)
 | 
			
		||||
        new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
 | 
			
		||||
        new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
 | 
			
		||||
        # clip
 | 
			
		||||
        new_instances.clip(*self.size)
 | 
			
		||||
 | 
			
		||||
@ -375,10 +375,10 @@ class RandomPerspective:
 | 
			
		||||
        i = self.box_candidates(box1=instances.bboxes.T,
 | 
			
		||||
                                box2=new_instances.bboxes.T,
 | 
			
		||||
                                area_thr=0.01 if len(segments) else 0.10)
 | 
			
		||||
        labels["instances"] = new_instances[i]
 | 
			
		||||
        labels["cls"] = cls[i]
 | 
			
		||||
        labels["img"] = img
 | 
			
		||||
        labels["resized_shape"] = img.shape[:2]
 | 
			
		||||
        labels['instances'] = new_instances[i]
 | 
			
		||||
        labels['cls'] = cls[i]
 | 
			
		||||
        labels['img'] = img
 | 
			
		||||
        labels['resized_shape'] = img.shape[:2]
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
    def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
 | 
			
		||||
@ -397,7 +397,7 @@ class RandomHSV:
 | 
			
		||||
        self.vgain = vgain
 | 
			
		||||
 | 
			
		||||
    def __call__(self, labels):
 | 
			
		||||
        img = labels["img"]
 | 
			
		||||
        img = labels['img']
 | 
			
		||||
        if self.hgain or self.sgain or self.vgain:
 | 
			
		||||
            r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains
 | 
			
		||||
            hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
 | 
			
		||||
@ -415,30 +415,30 @@ class RandomHSV:
 | 
			
		||||
 | 
			
		||||
class RandomFlip:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, p=0.5, direction="horizontal") -> None:
 | 
			
		||||
        assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
 | 
			
		||||
    def __init__(self, p=0.5, direction='horizontal') -> None:
 | 
			
		||||
        assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
 | 
			
		||||
        assert 0 <= p <= 1.0
 | 
			
		||||
 | 
			
		||||
        self.p = p
 | 
			
		||||
        self.direction = direction
 | 
			
		||||
 | 
			
		||||
    def __call__(self, labels):
 | 
			
		||||
        img = labels["img"]
 | 
			
		||||
        instances = labels.pop("instances")
 | 
			
		||||
        instances.convert_bbox(format="xywh")
 | 
			
		||||
        img = labels['img']
 | 
			
		||||
        instances = labels.pop('instances')
 | 
			
		||||
        instances.convert_bbox(format='xywh')
 | 
			
		||||
        h, w = img.shape[:2]
 | 
			
		||||
        h = 1 if instances.normalized else h
 | 
			
		||||
        w = 1 if instances.normalized else w
 | 
			
		||||
 | 
			
		||||
        # Flip up-down
 | 
			
		||||
        if self.direction == "vertical" and random.random() < self.p:
 | 
			
		||||
        if self.direction == 'vertical' and random.random() < self.p:
 | 
			
		||||
            img = np.flipud(img)
 | 
			
		||||
            instances.flipud(h)
 | 
			
		||||
        if self.direction == "horizontal" and random.random() < self.p:
 | 
			
		||||
        if self.direction == 'horizontal' and random.random() < self.p:
 | 
			
		||||
            img = np.fliplr(img)
 | 
			
		||||
            instances.fliplr(w)
 | 
			
		||||
        labels["img"] = np.ascontiguousarray(img)
 | 
			
		||||
        labels["instances"] = instances
 | 
			
		||||
        labels['img'] = np.ascontiguousarray(img)
 | 
			
		||||
        labels['instances'] = instances
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -455,9 +455,9 @@ class LetterBox:
 | 
			
		||||
    def __call__(self, labels=None, image=None):
 | 
			
		||||
        if labels is None:
 | 
			
		||||
            labels = {}
 | 
			
		||||
        img = labels.get("img") if image is None else image
 | 
			
		||||
        img = labels.get('img') if image is None else image
 | 
			
		||||
        shape = img.shape[:2]  # current shape [height, width]
 | 
			
		||||
        new_shape = labels.pop("rect_shape", self.new_shape)
 | 
			
		||||
        new_shape = labels.pop('rect_shape', self.new_shape)
 | 
			
		||||
        if isinstance(new_shape, int):
 | 
			
		||||
            new_shape = (new_shape, new_shape)
 | 
			
		||||
 | 
			
		||||
@ -479,8 +479,8 @@ class LetterBox:
 | 
			
		||||
 | 
			
		||||
        dw /= 2  # divide padding into 2 sides
 | 
			
		||||
        dh /= 2
 | 
			
		||||
        if labels.get("ratio_pad"):
 | 
			
		||||
            labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh))  # for evaluation
 | 
			
		||||
        if labels.get('ratio_pad'):
 | 
			
		||||
            labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh))  # for evaluation
 | 
			
		||||
 | 
			
		||||
        if shape[::-1] != new_unpad:  # resize
 | 
			
		||||
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
 | 
			
		||||
@ -491,18 +491,18 @@ class LetterBox:
 | 
			
		||||
 | 
			
		||||
        if len(labels):
 | 
			
		||||
            labels = self._update_labels(labels, ratio, dw, dh)
 | 
			
		||||
            labels["img"] = img
 | 
			
		||||
            labels["resized_shape"] = new_shape
 | 
			
		||||
            labels['img'] = img
 | 
			
		||||
            labels['resized_shape'] = new_shape
 | 
			
		||||
            return labels
 | 
			
		||||
        else:
 | 
			
		||||
            return img
 | 
			
		||||
 | 
			
		||||
    def _update_labels(self, labels, ratio, padw, padh):
 | 
			
		||||
        """Update labels"""
 | 
			
		||||
        labels["instances"].convert_bbox(format="xyxy")
 | 
			
		||||
        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
 | 
			
		||||
        labels["instances"].scale(*ratio)
 | 
			
		||||
        labels["instances"].add_padding(padw, padh)
 | 
			
		||||
        labels['instances'].convert_bbox(format='xyxy')
 | 
			
		||||
        labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
 | 
			
		||||
        labels['instances'].scale(*ratio)
 | 
			
		||||
        labels['instances'].add_padding(padw, padh)
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -513,11 +513,11 @@ class CopyPaste:
 | 
			
		||||
 | 
			
		||||
    def __call__(self, labels):
 | 
			
		||||
        # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
 | 
			
		||||
        im = labels["img"]
 | 
			
		||||
        cls = labels["cls"]
 | 
			
		||||
        im = labels['img']
 | 
			
		||||
        cls = labels['cls']
 | 
			
		||||
        h, w = im.shape[:2]
 | 
			
		||||
        instances = labels.pop("instances")
 | 
			
		||||
        instances.convert_bbox(format="xyxy")
 | 
			
		||||
        instances = labels.pop('instances')
 | 
			
		||||
        instances.convert_bbox(format='xyxy')
 | 
			
		||||
        instances.denormalize(w, h)
 | 
			
		||||
        if self.p and len(instances.segments):
 | 
			
		||||
            n = len(instances)
 | 
			
		||||
@ -540,9 +540,9 @@ class CopyPaste:
 | 
			
		||||
            i = cv2.flip(im_new, 1).astype(bool)
 | 
			
		||||
            im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug
 | 
			
		||||
 | 
			
		||||
        labels["img"] = im
 | 
			
		||||
        labels["cls"] = cls
 | 
			
		||||
        labels["instances"] = instances
 | 
			
		||||
        labels['img'] = im
 | 
			
		||||
        labels['cls'] = cls
 | 
			
		||||
        labels['instances'] = instances
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -551,11 +551,11 @@ class Albumentations:
 | 
			
		||||
    def __init__(self, p=1.0):
 | 
			
		||||
        self.p = p
 | 
			
		||||
        self.transform = None
 | 
			
		||||
        prefix = colorstr("albumentations: ")
 | 
			
		||||
        prefix = colorstr('albumentations: ')
 | 
			
		||||
        try:
 | 
			
		||||
            import albumentations as A
 | 
			
		||||
 | 
			
		||||
            check_version(A.__version__, "1.0.3", hard=True)  # version requirement
 | 
			
		||||
            check_version(A.__version__, '1.0.3', hard=True)  # version requirement
 | 
			
		||||
 | 
			
		||||
            T = [
 | 
			
		||||
                A.Blur(p=0.01),
 | 
			
		||||
@ -565,28 +565,28 @@ class Albumentations:
 | 
			
		||||
                A.RandomBrightnessContrast(p=0.0),
 | 
			
		||||
                A.RandomGamma(p=0.0),
 | 
			
		||||
                A.ImageCompression(quality_lower=75, p=0.0),]  # transforms
 | 
			
		||||
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
 | 
			
		||||
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
 | 
			
		||||
 | 
			
		||||
            LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
 | 
			
		||||
            LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
 | 
			
		||||
        except ImportError:  # package not installed, skip
 | 
			
		||||
            pass
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            LOGGER.info(f"{prefix}{e}")
 | 
			
		||||
            LOGGER.info(f'{prefix}{e}')
 | 
			
		||||
 | 
			
		||||
    def __call__(self, labels):
 | 
			
		||||
        im = labels["img"]
 | 
			
		||||
        cls = labels["cls"]
 | 
			
		||||
        im = labels['img']
 | 
			
		||||
        cls = labels['cls']
 | 
			
		||||
        if len(cls):
 | 
			
		||||
            labels["instances"].convert_bbox("xywh")
 | 
			
		||||
            labels["instances"].normalize(*im.shape[:2][::-1])
 | 
			
		||||
            bboxes = labels["instances"].bboxes
 | 
			
		||||
            labels['instances'].convert_bbox('xywh')
 | 
			
		||||
            labels['instances'].normalize(*im.shape[:2][::-1])
 | 
			
		||||
            bboxes = labels['instances'].bboxes
 | 
			
		||||
            # TODO: add supports of segments and keypoints
 | 
			
		||||
            if self.transform and random.random() < self.p:
 | 
			
		||||
                new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
 | 
			
		||||
                labels["img"] = new["image"]
 | 
			
		||||
                labels["cls"] = np.array(new["class_labels"])
 | 
			
		||||
                bboxes = np.array(new["bboxes"])
 | 
			
		||||
            labels["instances"].update(bboxes=bboxes)
 | 
			
		||||
                labels['img'] = new['image']
 | 
			
		||||
                labels['cls'] = np.array(new['class_labels'])
 | 
			
		||||
                bboxes = np.array(new['bboxes'])
 | 
			
		||||
            labels['instances'].update(bboxes=bboxes)
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -594,7 +594,7 @@ class Albumentations:
 | 
			
		||||
class Format:
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 bbox_format="xywh",
 | 
			
		||||
                 bbox_format='xywh',
 | 
			
		||||
                 normalize=True,
 | 
			
		||||
                 return_mask=False,
 | 
			
		||||
                 return_keypoint=False,
 | 
			
		||||
@ -610,10 +610,10 @@ class Format:
 | 
			
		||||
        self.batch_idx = batch_idx  # keep the batch indexes
 | 
			
		||||
 | 
			
		||||
    def __call__(self, labels):
 | 
			
		||||
        img = labels.pop("img")
 | 
			
		||||
        img = labels.pop('img')
 | 
			
		||||
        h, w = img.shape[:2]
 | 
			
		||||
        cls = labels.pop("cls")
 | 
			
		||||
        instances = labels.pop("instances")
 | 
			
		||||
        cls = labels.pop('cls')
 | 
			
		||||
        instances = labels.pop('instances')
 | 
			
		||||
        instances.convert_bbox(format=self.bbox_format)
 | 
			
		||||
        instances.denormalize(w, h)
 | 
			
		||||
        nl = len(instances)
 | 
			
		||||
@ -625,17 +625,17 @@ class Format:
 | 
			
		||||
            else:
 | 
			
		||||
                masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
 | 
			
		||||
                                    img.shape[1] // self.mask_ratio)
 | 
			
		||||
            labels["masks"] = masks
 | 
			
		||||
            labels['masks'] = masks
 | 
			
		||||
        if self.normalize:
 | 
			
		||||
            instances.normalize(w, h)
 | 
			
		||||
        labels["img"] = self._format_img(img)
 | 
			
		||||
        labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
 | 
			
		||||
        labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
 | 
			
		||||
        labels['img'] = self._format_img(img)
 | 
			
		||||
        labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
 | 
			
		||||
        labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
 | 
			
		||||
        if self.return_keypoint:
 | 
			
		||||
            labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
 | 
			
		||||
            labels['keypoints'] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
 | 
			
		||||
        # then we can use collate_fn
 | 
			
		||||
        if self.batch_idx:
 | 
			
		||||
            labels["batch_idx"] = torch.zeros(nl)
 | 
			
		||||
            labels['batch_idx'] = torch.zeros(nl)
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
    def _format_img(self, img):
 | 
			
		||||
@ -676,15 +676,15 @@ def v8_transforms(dataset, imgsz, hyp):
 | 
			
		||||
        MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
 | 
			
		||||
        Albumentations(p=1.0),
 | 
			
		||||
        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
 | 
			
		||||
        RandomFlip(direction="vertical", p=hyp.flipud),
 | 
			
		||||
        RandomFlip(direction="horizontal", p=hyp.fliplr),])  # transforms
 | 
			
		||||
        RandomFlip(direction='vertical', p=hyp.flipud),
 | 
			
		||||
        RandomFlip(direction='horizontal', p=hyp.fliplr),])  # transforms
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Classification augmentations -----------------------------------------------------------------------------------------
 | 
			
		||||
def classify_transforms(size=224):
 | 
			
		||||
    # Transforms to apply if albumentations not installed
 | 
			
		||||
    if not isinstance(size, int):
 | 
			
		||||
        raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
 | 
			
		||||
        raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
 | 
			
		||||
    # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
 | 
			
		||||
    return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
 | 
			
		||||
 | 
			
		||||
@ -701,17 +701,17 @@ def classify_albumentations(
 | 
			
		||||
        auto_aug=False,
 | 
			
		||||
):
 | 
			
		||||
    # YOLOv8 classification Albumentations (optional, only used if package is installed)
 | 
			
		||||
    prefix = colorstr("albumentations: ")
 | 
			
		||||
    prefix = colorstr('albumentations: ')
 | 
			
		||||
    try:
 | 
			
		||||
        import albumentations as A
 | 
			
		||||
        from albumentations.pytorch import ToTensorV2
 | 
			
		||||
 | 
			
		||||
        check_version(A.__version__, "1.0.3", hard=True)  # version requirement
 | 
			
		||||
        check_version(A.__version__, '1.0.3', hard=True)  # version requirement
 | 
			
		||||
        if augment:  # Resize and crop
 | 
			
		||||
            T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
 | 
			
		||||
            if auto_aug:
 | 
			
		||||
                # TODO: implement AugMix, AutoAug & RandAug in albumentation
 | 
			
		||||
                LOGGER.info(f"{prefix}auto augmentations are currently not supported")
 | 
			
		||||
                LOGGER.info(f'{prefix}auto augmentations are currently not supported')
 | 
			
		||||
            else:
 | 
			
		||||
                if hflip > 0:
 | 
			
		||||
                    T += [A.HorizontalFlip(p=hflip)]
 | 
			
		||||
@ -723,13 +723,13 @@ def classify_albumentations(
 | 
			
		||||
        else:  # Use fixed crop for eval set (reproducibility)
 | 
			
		||||
            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
 | 
			
		||||
        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
 | 
			
		||||
        LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
 | 
			
		||||
        LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
 | 
			
		||||
        return A.Compose(T)
 | 
			
		||||
 | 
			
		||||
    except ImportError:  # package not installed, skip
 | 
			
		||||
        pass
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        LOGGER.info(f"{prefix}{e}")
 | 
			
		||||
        LOGGER.info(f'{prefix}{e}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClassifyLetterBox:
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ class BaseDataset(Dataset):
 | 
			
		||||
        cache=False,
 | 
			
		||||
        augment=True,
 | 
			
		||||
        hyp=None,
 | 
			
		||||
        prefix="",
 | 
			
		||||
        prefix='',
 | 
			
		||||
        rect=False,
 | 
			
		||||
        batch_size=None,
 | 
			
		||||
        stride=32,
 | 
			
		||||
@ -63,7 +63,7 @@ class BaseDataset(Dataset):
 | 
			
		||||
 | 
			
		||||
        # cache stuff
 | 
			
		||||
        self.ims = [None] * self.ni
 | 
			
		||||
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
 | 
			
		||||
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
 | 
			
		||||
        if cache:
 | 
			
		||||
            self.cache_images(cache)
 | 
			
		||||
 | 
			
		||||
@ -77,21 +77,21 @@ class BaseDataset(Dataset):
 | 
			
		||||
            for p in img_path if isinstance(img_path, list) else [img_path]:
 | 
			
		||||
                p = Path(p)  # os-agnostic
 | 
			
		||||
                if p.is_dir():  # dir
 | 
			
		||||
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
 | 
			
		||||
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
 | 
			
		||||
                    # f = list(p.rglob('*.*'))  # pathlib
 | 
			
		||||
                elif p.is_file():  # file
 | 
			
		||||
                    with open(p) as t:
 | 
			
		||||
                        t = t.read().strip().splitlines()
 | 
			
		||||
                        parent = str(p.parent) + os.sep
 | 
			
		||||
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
 | 
			
		||||
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
 | 
			
		||||
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
 | 
			
		||||
                else:
 | 
			
		||||
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
 | 
			
		||||
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
 | 
			
		||||
                    raise FileNotFoundError(f'{self.prefix}{p} does not exist')
 | 
			
		||||
            im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
 | 
			
		||||
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
 | 
			
		||||
            assert im_files, f"{self.prefix}No images found"
 | 
			
		||||
            assert im_files, f'{self.prefix}No images found'
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
 | 
			
		||||
            raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
 | 
			
		||||
        return im_files
 | 
			
		||||
 | 
			
		||||
    def update_labels(self, include_class: Optional[list]):
 | 
			
		||||
@ -99,16 +99,16 @@ class BaseDataset(Dataset):
 | 
			
		||||
        include_class_array = np.array(include_class).reshape(1, -1)
 | 
			
		||||
        for i in range(len(self.labels)):
 | 
			
		||||
            if include_class:
 | 
			
		||||
                cls = self.labels[i]["cls"]
 | 
			
		||||
                bboxes = self.labels[i]["bboxes"]
 | 
			
		||||
                segments = self.labels[i]["segments"]
 | 
			
		||||
                cls = self.labels[i]['cls']
 | 
			
		||||
                bboxes = self.labels[i]['bboxes']
 | 
			
		||||
                segments = self.labels[i]['segments']
 | 
			
		||||
                j = (cls == include_class_array).any(1)
 | 
			
		||||
                self.labels[i]["cls"] = cls[j]
 | 
			
		||||
                self.labels[i]["bboxes"] = bboxes[j]
 | 
			
		||||
                self.labels[i]['cls'] = cls[j]
 | 
			
		||||
                self.labels[i]['bboxes'] = bboxes[j]
 | 
			
		||||
                if segments:
 | 
			
		||||
                    self.labels[i]["segments"] = segments[j]
 | 
			
		||||
                    self.labels[i]['segments'] = segments[j]
 | 
			
		||||
            if self.single_cls:
 | 
			
		||||
                self.labels[i]["cls"][:, 0] = 0
 | 
			
		||||
                self.labels[i]['cls'][:, 0] = 0
 | 
			
		||||
 | 
			
		||||
    def load_image(self, i):
 | 
			
		||||
        # Loads 1 image from dataset index 'i', returns (im, resized hw)
 | 
			
		||||
@ -119,7 +119,7 @@ class BaseDataset(Dataset):
 | 
			
		||||
            else:  # read image
 | 
			
		||||
                im = cv2.imread(f)  # BGR
 | 
			
		||||
                if im is None:
 | 
			
		||||
                    raise FileNotFoundError(f"Image Not Found {f}")
 | 
			
		||||
                    raise FileNotFoundError(f'Image Not Found {f}')
 | 
			
		||||
            h0, w0 = im.shape[:2]  # orig hw
 | 
			
		||||
            r = self.imgsz / max(h0, w0)  # ratio
 | 
			
		||||
            if r != 1:  # if sizes are not equal
 | 
			
		||||
@ -132,17 +132,17 @@ class BaseDataset(Dataset):
 | 
			
		||||
        # cache images to memory or disk
 | 
			
		||||
        gb = 0  # Gigabytes of cached images
 | 
			
		||||
        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
 | 
			
		||||
        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
 | 
			
		||||
        fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
 | 
			
		||||
        with ThreadPool(NUM_THREADS) as pool:
 | 
			
		||||
            results = pool.imap(fcn, range(self.ni))
 | 
			
		||||
            pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
 | 
			
		||||
            for i, x in pbar:
 | 
			
		||||
                if cache == "disk":
 | 
			
		||||
                if cache == 'disk':
 | 
			
		||||
                    gb += self.npy_files[i].stat().st_size
 | 
			
		||||
                else:  # 'ram'
 | 
			
		||||
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
 | 
			
		||||
                    gb += self.ims[i].nbytes
 | 
			
		||||
                pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
 | 
			
		||||
                pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
 | 
			
		||||
            pbar.close()
 | 
			
		||||
 | 
			
		||||
    def cache_images_to_disk(self, i):
 | 
			
		||||
@ -155,7 +155,7 @@ class BaseDataset(Dataset):
 | 
			
		||||
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
 | 
			
		||||
        nb = bi[-1] + 1  # number of batches
 | 
			
		||||
 | 
			
		||||
        s = np.array([x.pop("shape") for x in self.labels])  # hw
 | 
			
		||||
        s = np.array([x.pop('shape') for x in self.labels])  # hw
 | 
			
		||||
        ar = s[:, 0] / s[:, 1]  # aspect ratio
 | 
			
		||||
        irect = ar.argsort()
 | 
			
		||||
        self.im_files = [self.im_files[i] for i in irect]
 | 
			
		||||
@ -180,14 +180,14 @@ class BaseDataset(Dataset):
 | 
			
		||||
 | 
			
		||||
    def get_label_info(self, index):
 | 
			
		||||
        label = self.labels[index].copy()
 | 
			
		||||
        label.pop("shape", None)  # shape is for rect, remove it
 | 
			
		||||
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
 | 
			
		||||
        label["ratio_pad"] = (
 | 
			
		||||
            label["resized_shape"][0] / label["ori_shape"][0],
 | 
			
		||||
            label["resized_shape"][1] / label["ori_shape"][1],
 | 
			
		||||
        label.pop('shape', None)  # shape is for rect, remove it
 | 
			
		||||
        label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
 | 
			
		||||
        label['ratio_pad'] = (
 | 
			
		||||
            label['resized_shape'][0] / label['ori_shape'][0],
 | 
			
		||||
            label['resized_shape'][1] / label['ori_shape'][1],
 | 
			
		||||
        )  # for evaluation
 | 
			
		||||
        if self.rect:
 | 
			
		||||
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
 | 
			
		||||
            label['rect_shape'] = self.batch_shapes[self.batch[index]]
 | 
			
		||||
        label = self.update_labels_info(label)
 | 
			
		||||
        return label
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
 | 
			
		||||
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
 | 
			
		||||
        self.iterator = super().__iter__()
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
@ -61,9 +61,9 @@ def seed_worker(worker_id):
 | 
			
		||||
    random.seed(worker_seed)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
 | 
			
		||||
    assert mode in ["train", "val"]
 | 
			
		||||
    shuffle = mode == "train"
 | 
			
		||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
 | 
			
		||||
    assert mode in ['train', 'val']
 | 
			
		||||
    shuffle = mode == 'train'
 | 
			
		||||
    if cfg.rect and shuffle:
 | 
			
		||||
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
 | 
			
		||||
        shuffle = False
 | 
			
		||||
@ -72,21 +72,21 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
 | 
			
		||||
            img_path=img_path,
 | 
			
		||||
            imgsz=cfg.imgsz,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            augment=mode == "train",  # augmentation
 | 
			
		||||
            augment=mode == 'train',  # augmentation
 | 
			
		||||
            hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
 | 
			
		||||
            rect=cfg.rect or rect,  # rectangular batches
 | 
			
		||||
            cache=cfg.cache or None,
 | 
			
		||||
            single_cls=cfg.single_cls or False,
 | 
			
		||||
            stride=int(stride),
 | 
			
		||||
            pad=0.0 if mode == "train" else 0.5,
 | 
			
		||||
            prefix=colorstr(f"{mode}: "),
 | 
			
		||||
            use_segments=cfg.task == "segment",
 | 
			
		||||
            use_keypoints=cfg.task == "keypoint",
 | 
			
		||||
            pad=0.0 if mode == 'train' else 0.5,
 | 
			
		||||
            prefix=colorstr(f'{mode}: '),
 | 
			
		||||
            use_segments=cfg.task == 'segment',
 | 
			
		||||
            use_keypoints=cfg.task == 'keypoint',
 | 
			
		||||
            names=names)
 | 
			
		||||
 | 
			
		||||
    batch = min(batch, len(dataset))
 | 
			
		||||
    nd = torch.cuda.device_count()  # number of CUDA devices
 | 
			
		||||
    workers = cfg.workers if mode == "train" else cfg.workers * 2
 | 
			
		||||
    workers = cfg.workers if mode == 'train' else cfg.workers * 2
 | 
			
		||||
    nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers
 | 
			
		||||
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
 | 
			
		||||
    loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader  # allow attribute updates
 | 
			
		||||
@ -98,7 +98,7 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
 | 
			
		||||
                  num_workers=nw,
 | 
			
		||||
                  sampler=sampler,
 | 
			
		||||
                  pin_memory=PIN_MEMORY,
 | 
			
		||||
                  collate_fn=getattr(dataset, "collate_fn", None),
 | 
			
		||||
                  collate_fn=getattr(dataset, 'collate_fn', None),
 | 
			
		||||
                  worker_init_fn=seed_worker,
 | 
			
		||||
                  generator=generator), dataset
 | 
			
		||||
 | 
			
		||||
@ -151,7 +151,7 @@ def check_source(source):
 | 
			
		||||
        from_img = True
 | 
			
		||||
    else:
 | 
			
		||||
        raise Exception(
 | 
			
		||||
            "Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
 | 
			
		||||
            'Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict')
 | 
			
		||||
 | 
			
		||||
    return source, webcam, screenshot, from_img, in_memory
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -47,7 +47,7 @@ class LoadStreams:
 | 
			
		||||
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
 | 
			
		||||
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
 | 
			
		||||
                import pafy  # noqa
 | 
			
		||||
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
 | 
			
		||||
                s = pafy.new(s).getbest(preftype='mp4').url  # YouTube URL
 | 
			
		||||
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
 | 
			
		||||
            if s == 0 and (is_colab() or is_kaggle()):
 | 
			
		||||
                raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
 | 
			
		||||
@ -65,7 +65,7 @@ class LoadStreams:
 | 
			
		||||
            if not success or self.imgs[i] is None:
 | 
			
		||||
                raise ConnectionError(f'{st}Failed to read images from {s}')
 | 
			
		||||
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
 | 
			
		||||
            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
 | 
			
		||||
            LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
 | 
			
		||||
            self.threads[i].start()
 | 
			
		||||
        LOGGER.info('')  # newline
 | 
			
		||||
 | 
			
		||||
@ -145,11 +145,11 @@ class LoadScreenshots:
 | 
			
		||||
 | 
			
		||||
        # Parse monitor shape
 | 
			
		||||
        monitor = self.sct.monitors[self.screen]
 | 
			
		||||
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
 | 
			
		||||
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
 | 
			
		||||
        self.width = width or monitor["width"]
 | 
			
		||||
        self.height = height or monitor["height"]
 | 
			
		||||
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
 | 
			
		||||
        self.top = monitor['top'] if top is None else (monitor['top'] + top)
 | 
			
		||||
        self.left = monitor['left'] if left is None else (monitor['left'] + left)
 | 
			
		||||
        self.width = width or monitor['width']
 | 
			
		||||
        self.height = height or monitor['height']
 | 
			
		||||
        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        return self
 | 
			
		||||
@ -157,7 +157,7 @@ class LoadScreenshots:
 | 
			
		||||
    def __next__(self):
 | 
			
		||||
        # mss screen capture: get raw pixels from the screen as np array
 | 
			
		||||
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
 | 
			
		||||
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
 | 
			
		||||
        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
 | 
			
		||||
 | 
			
		||||
        if self.transforms:
 | 
			
		||||
            im = self.transforms(im0)  # transforms
 | 
			
		||||
@ -172,7 +172,7 @@ class LoadScreenshots:
 | 
			
		||||
class LoadImages:
 | 
			
		||||
    # YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
 | 
			
		||||
    def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
 | 
			
		||||
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
 | 
			
		||||
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
 | 
			
		||||
            path = Path(path).read_text().rsplit()
 | 
			
		||||
        files = []
 | 
			
		||||
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
 | 
			
		||||
@ -290,12 +290,12 @@ class LoadPilAndNumpy:
 | 
			
		||||
        self.transforms = transforms
 | 
			
		||||
        self.mode = 'image'
 | 
			
		||||
        # generate fake paths
 | 
			
		||||
        self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
 | 
			
		||||
        self.paths = [f'image{i}.jpg' for i in range(len(self.im0))]
 | 
			
		||||
        self.bs = len(self.im0)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _single_check(im):
 | 
			
		||||
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
 | 
			
		||||
        assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
 | 
			
		||||
        if isinstance(im, Image.Image):
 | 
			
		||||
            im = np.asarray(im)[:, :, ::-1]
 | 
			
		||||
            im = np.ascontiguousarray(im)  # contiguous
 | 
			
		||||
@ -338,16 +338,16 @@ def autocast_list(source):
 | 
			
		||||
        elif isinstance(im, (Image.Image, np.ndarray)):  # PIL or np Image
 | 
			
		||||
            files.append(im)
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError(f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
 | 
			
		||||
                            f"See https://docs.ultralytics.com/predict for supported source types.")
 | 
			
		||||
            raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
 | 
			
		||||
                            f'See https://docs.ultralytics.com/predict for supported source types.')
 | 
			
		||||
 | 
			
		||||
    return files
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    img = cv2.imread(str(ROOT / "assets/bus.jpg"))
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
 | 
			
		||||
    dataset = LoadPilAndNumpy(im0=img)
 | 
			
		||||
    for d in dataset:
 | 
			
		||||
        print(d[0])
 | 
			
		||||
 | 
			
		||||
@ -92,7 +92,7 @@ def exif_transpose(image):
 | 
			
		||||
        if method is not None:
 | 
			
		||||
            image = image.transpose(method)
 | 
			
		||||
            del exif[0x0112]
 | 
			
		||||
            image.info["exif"] = exif.tobytes()
 | 
			
		||||
            image.info['exif'] = exif.tobytes()
 | 
			
		||||
    return image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -217,11 +217,11 @@ class LoadScreenshots:
 | 
			
		||||
 | 
			
		||||
        # Parse monitor shape
 | 
			
		||||
        monitor = self.sct.monitors[self.screen]
 | 
			
		||||
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
 | 
			
		||||
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
 | 
			
		||||
        self.width = width or monitor["width"]
 | 
			
		||||
        self.height = height or monitor["height"]
 | 
			
		||||
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
 | 
			
		||||
        self.top = monitor['top'] if top is None else (monitor['top'] + top)
 | 
			
		||||
        self.left = monitor['left'] if left is None else (monitor['left'] + left)
 | 
			
		||||
        self.width = width or monitor['width']
 | 
			
		||||
        self.height = height or monitor['height']
 | 
			
		||||
        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        return self
 | 
			
		||||
@ -229,7 +229,7 @@ class LoadScreenshots:
 | 
			
		||||
    def __next__(self):
 | 
			
		||||
        # mss screen capture: get raw pixels from the screen as np array
 | 
			
		||||
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
 | 
			
		||||
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
 | 
			
		||||
        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
 | 
			
		||||
 | 
			
		||||
        if self.transforms:
 | 
			
		||||
            im = self.transforms(im0)  # transforms
 | 
			
		||||
@ -244,7 +244,7 @@ class LoadScreenshots:
 | 
			
		||||
class LoadImages:
 | 
			
		||||
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
 | 
			
		||||
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
 | 
			
		||||
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
 | 
			
		||||
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
 | 
			
		||||
            path = Path(path).read_text().rsplit()
 | 
			
		||||
        files = []
 | 
			
		||||
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
 | 
			
		||||
@ -363,7 +363,7 @@ class LoadStreams:
 | 
			
		||||
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
 | 
			
		||||
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
 | 
			
		||||
                import pafy
 | 
			
		||||
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
 | 
			
		||||
                s = pafy.new(s).getbest(preftype='mp4').url  # YouTube URL
 | 
			
		||||
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
 | 
			
		||||
            if s == 0:
 | 
			
		||||
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
 | 
			
		||||
@ -378,7 +378,7 @@ class LoadStreams:
 | 
			
		||||
 | 
			
		||||
            _, self.imgs[i] = cap.read()  # guarantee first frame
 | 
			
		||||
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
 | 
			
		||||
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
 | 
			
		||||
            LOGGER.info(f'{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)')
 | 
			
		||||
            self.threads[i].start()
 | 
			
		||||
        LOGGER.info('')  # newline
 | 
			
		||||
 | 
			
		||||
@ -500,7 +500,7 @@ class LoadImagesAndLabels(Dataset):
 | 
			
		||||
        # Display cache
 | 
			
		||||
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
 | 
			
		||||
        if exists and LOCAL_RANK in {-1, 0}:
 | 
			
		||||
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
 | 
			
		||||
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
 | 
			
		||||
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
 | 
			
		||||
            if cache['msgs']:
 | 
			
		||||
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
 | 
			
		||||
@ -604,8 +604,8 @@ class LoadImagesAndLabels(Dataset):
 | 
			
		||||
        mem = psutil.virtual_memory()
 | 
			
		||||
        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
 | 
			
		||||
        if not cache:
 | 
			
		||||
            LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
 | 
			
		||||
                        f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
 | 
			
		||||
            LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
 | 
			
		||||
                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
 | 
			
		||||
                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
 | 
			
		||||
        return cache
 | 
			
		||||
 | 
			
		||||
@ -615,7 +615,7 @@ class LoadImagesAndLabels(Dataset):
 | 
			
		||||
            path.unlink()  # remove *.cache file if exists
 | 
			
		||||
        x = {}  # dict
 | 
			
		||||
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
 | 
			
		||||
        desc = f"{prefix}Scanning {path.parent / path.stem}..."
 | 
			
		||||
        desc = f'{prefix}Scanning {path.parent / path.stem}...'
 | 
			
		||||
        total = len(self.im_files)
 | 
			
		||||
        with ThreadPool(NUM_THREADS) as pool:
 | 
			
		||||
            results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
 | 
			
		||||
@ -629,7 +629,7 @@ class LoadImagesAndLabels(Dataset):
 | 
			
		||||
                    x[im_file] = [lb, shape, segments]
 | 
			
		||||
                if msg:
 | 
			
		||||
                    msgs.append(msg)
 | 
			
		||||
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
 | 
			
		||||
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
 | 
			
		||||
            pbar.close()
 | 
			
		||||
 | 
			
		||||
        if msgs:
 | 
			
		||||
@ -1060,7 +1060,7 @@ class HUBDatasetStats():
 | 
			
		||||
            if zipped:
 | 
			
		||||
                data['path'] = data_dir
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise Exception("error/HUB/dataset_stats/yaml_load") from e
 | 
			
		||||
            raise Exception('error/HUB/dataset_stats/yaml_load') from e
 | 
			
		||||
 | 
			
		||||
        check_det_dataset(data, autodownload)  # download dataset if missing
 | 
			
		||||
        self.hub_dir = Path(data['path'] + '-hub')
 | 
			
		||||
@ -1187,7 +1187,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 | 
			
		||||
        else:  # read image
 | 
			
		||||
            im = cv2.imread(f)  # BGR
 | 
			
		||||
        if self.album_transforms:
 | 
			
		||||
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
 | 
			
		||||
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
 | 
			
		||||
        else:
 | 
			
		||||
            sample = self.torch_transforms(im)
 | 
			
		||||
        return sample, j
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
                 cache=False,
 | 
			
		||||
                 augment=True,
 | 
			
		||||
                 hyp=None,
 | 
			
		||||
                 prefix="",
 | 
			
		||||
                 prefix='',
 | 
			
		||||
                 rect=False,
 | 
			
		||||
                 batch_size=None,
 | 
			
		||||
                 stride=32,
 | 
			
		||||
@ -40,14 +40,14 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
        self.use_segments = use_segments
 | 
			
		||||
        self.use_keypoints = use_keypoints
 | 
			
		||||
        self.names = names
 | 
			
		||||
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
 | 
			
		||||
        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
 | 
			
		||||
        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
 | 
			
		||||
 | 
			
		||||
    def cache_labels(self, path=Path("./labels.cache")):
 | 
			
		||||
    def cache_labels(self, path=Path('./labels.cache')):
 | 
			
		||||
        # Cache dataset labels, check images and read shapes
 | 
			
		||||
        x = {"labels": []}
 | 
			
		||||
        x = {'labels': []}
 | 
			
		||||
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
 | 
			
		||||
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
 | 
			
		||||
        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
 | 
			
		||||
        total = len(self.im_files)
 | 
			
		||||
        with ThreadPool(NUM_THREADS) as pool:
 | 
			
		||||
            results = pool.imap(func=verify_image_label,
 | 
			
		||||
@ -60,7 +60,7 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
                ne += ne_f
 | 
			
		||||
                nc += nc_f
 | 
			
		||||
                if im_file:
 | 
			
		||||
                    x["labels"].append(
 | 
			
		||||
                    x['labels'].append(
 | 
			
		||||
                        dict(
 | 
			
		||||
                            im_file=im_file,
 | 
			
		||||
                            shape=shape,
 | 
			
		||||
@ -69,68 +69,68 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
                            segments=segments,
 | 
			
		||||
                            keypoints=keypoint,
 | 
			
		||||
                            normalized=True,
 | 
			
		||||
                            bbox_format="xywh"))
 | 
			
		||||
                            bbox_format='xywh'))
 | 
			
		||||
                if msg:
 | 
			
		||||
                    msgs.append(msg)
 | 
			
		||||
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
 | 
			
		||||
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
 | 
			
		||||
            pbar.close()
 | 
			
		||||
 | 
			
		||||
        if msgs:
 | 
			
		||||
            LOGGER.info("\n".join(msgs))
 | 
			
		||||
            LOGGER.info('\n'.join(msgs))
 | 
			
		||||
        if nf == 0:
 | 
			
		||||
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
 | 
			
		||||
        x["hash"] = get_hash(self.label_files + self.im_files)
 | 
			
		||||
        x["results"] = nf, nm, ne, nc, len(self.im_files)
 | 
			
		||||
        x["msgs"] = msgs  # warnings
 | 
			
		||||
        x["version"] = self.cache_version  # cache version
 | 
			
		||||
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
 | 
			
		||||
        x['hash'] = get_hash(self.label_files + self.im_files)
 | 
			
		||||
        x['results'] = nf, nm, ne, nc, len(self.im_files)
 | 
			
		||||
        x['msgs'] = msgs  # warnings
 | 
			
		||||
        x['version'] = self.cache_version  # cache version
 | 
			
		||||
        if is_dir_writeable(path.parent):
 | 
			
		||||
            if path.exists():
 | 
			
		||||
                path.unlink()  # remove *.cache file if exists
 | 
			
		||||
            np.save(str(path), x)  # save cache for next time
 | 
			
		||||
            path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
 | 
			
		||||
            LOGGER.info(f"{self.prefix}New cache created: {path}")
 | 
			
		||||
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
 | 
			
		||||
            LOGGER.info(f'{self.prefix}New cache created: {path}')
 | 
			
		||||
        else:
 | 
			
		||||
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
 | 
			
		||||
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_labels(self):
 | 
			
		||||
        self.label_files = img2label_paths(self.im_files)
 | 
			
		||||
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
 | 
			
		||||
        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
 | 
			
		||||
        try:
 | 
			
		||||
            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
 | 
			
		||||
            assert cache["version"] == self.cache_version  # matches current version
 | 
			
		||||
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
 | 
			
		||||
            assert cache['version'] == self.cache_version  # matches current version
 | 
			
		||||
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
 | 
			
		||||
        except (FileNotFoundError, AssertionError, AttributeError):
 | 
			
		||||
            cache, exists = self.cache_labels(cache_path), False  # run cache ops
 | 
			
		||||
 | 
			
		||||
        # Display cache
 | 
			
		||||
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
 | 
			
		||||
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
 | 
			
		||||
        if exists and LOCAL_RANK in {-1, 0}:
 | 
			
		||||
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
 | 
			
		||||
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
 | 
			
		||||
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
 | 
			
		||||
            if cache["msgs"]:
 | 
			
		||||
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
 | 
			
		||||
            if cache['msgs']:
 | 
			
		||||
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
 | 
			
		||||
        if nf == 0:  # number of labels found
 | 
			
		||||
            raise FileNotFoundError(f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}")
 | 
			
		||||
            raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
 | 
			
		||||
 | 
			
		||||
        # Read cache
 | 
			
		||||
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
 | 
			
		||||
        labels = cache["labels"]
 | 
			
		||||
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files
 | 
			
		||||
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
 | 
			
		||||
        labels = cache['labels']
 | 
			
		||||
        self.im_files = [lb['im_file'] for lb in labels]  # update im_files
 | 
			
		||||
 | 
			
		||||
        # Check if the dataset is all boxes or all segments
 | 
			
		||||
        len_cls = sum(len(lb["cls"]) for lb in labels)
 | 
			
		||||
        len_boxes = sum(len(lb["bboxes"]) for lb in labels)
 | 
			
		||||
        len_segments = sum(len(lb["segments"]) for lb in labels)
 | 
			
		||||
        len_cls = sum(len(lb['cls']) for lb in labels)
 | 
			
		||||
        len_boxes = sum(len(lb['bboxes']) for lb in labels)
 | 
			
		||||
        len_segments = sum(len(lb['segments']) for lb in labels)
 | 
			
		||||
        if len_segments and len_boxes != len_segments:
 | 
			
		||||
            LOGGER.warning(
 | 
			
		||||
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
 | 
			
		||||
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
 | 
			
		||||
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
 | 
			
		||||
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
 | 
			
		||||
                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
 | 
			
		||||
                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
 | 
			
		||||
            for lb in labels:
 | 
			
		||||
                lb["segments"] = []
 | 
			
		||||
                lb['segments'] = []
 | 
			
		||||
        if len_cls == 0:
 | 
			
		||||
            raise ValueError(f"All labels empty in {cache_path}, can not start training without labels. {HELP_URL}")
 | 
			
		||||
            raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
    # TODO: use hyp config to set all these augmentations
 | 
			
		||||
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
        else:
 | 
			
		||||
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
 | 
			
		||||
        transforms.append(
 | 
			
		||||
            Format(bbox_format="xywh",
 | 
			
		||||
            Format(bbox_format='xywh',
 | 
			
		||||
                   normalize=True,
 | 
			
		||||
                   return_mask=self.use_segments,
 | 
			
		||||
                   return_keypoint=self.use_keypoints,
 | 
			
		||||
@ -161,12 +161,12 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
        """custom your label format here"""
 | 
			
		||||
        # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
 | 
			
		||||
        # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
 | 
			
		||||
        bboxes = label.pop("bboxes")
 | 
			
		||||
        segments = label.pop("segments")
 | 
			
		||||
        keypoints = label.pop("keypoints", None)
 | 
			
		||||
        bbox_format = label.pop("bbox_format")
 | 
			
		||||
        normalized = label.pop("normalized")
 | 
			
		||||
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
 | 
			
		||||
        bboxes = label.pop('bboxes')
 | 
			
		||||
        segments = label.pop('segments')
 | 
			
		||||
        keypoints = label.pop('keypoints', None)
 | 
			
		||||
        bbox_format = label.pop('bbox_format')
 | 
			
		||||
        normalized = label.pop('normalized')
 | 
			
		||||
        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
 | 
			
		||||
        return label
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -176,15 +176,15 @@ class YOLODataset(BaseDataset):
 | 
			
		||||
        values = list(zip(*[list(b.values()) for b in batch]))
 | 
			
		||||
        for i, k in enumerate(keys):
 | 
			
		||||
            value = values[i]
 | 
			
		||||
            if k == "img":
 | 
			
		||||
            if k == 'img':
 | 
			
		||||
                value = torch.stack(value, 0)
 | 
			
		||||
            if k in ["masks", "keypoints", "bboxes", "cls"]:
 | 
			
		||||
            if k in ['masks', 'keypoints', 'bboxes', 'cls']:
 | 
			
		||||
                value = torch.cat(value, 0)
 | 
			
		||||
            new_batch[k] = value
 | 
			
		||||
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
 | 
			
		||||
        for i in range(len(new_batch["batch_idx"])):
 | 
			
		||||
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
 | 
			
		||||
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
 | 
			
		||||
        new_batch['batch_idx'] = list(new_batch['batch_idx'])
 | 
			
		||||
        for i in range(len(new_batch['batch_idx'])):
 | 
			
		||||
            new_batch['batch_idx'][i] += i  # add target image index for build_targets()
 | 
			
		||||
        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
 | 
			
		||||
        return new_batch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -202,9 +202,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 | 
			
		||||
        super().__init__(root=root)
 | 
			
		||||
        self.torch_transforms = classify_transforms(imgsz)
 | 
			
		||||
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
 | 
			
		||||
        self.cache_ram = cache is True or cache == "ram"
 | 
			
		||||
        self.cache_disk = cache == "disk"
 | 
			
		||||
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
 | 
			
		||||
        self.cache_ram = cache is True or cache == 'ram'
 | 
			
		||||
        self.cache_disk = cache == 'disk'
 | 
			
		||||
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, i):
 | 
			
		||||
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
 | 
			
		||||
@ -217,7 +217,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 | 
			
		||||
        else:  # read image
 | 
			
		||||
            im = cv2.imread(f)  # BGR
 | 
			
		||||
        if self.album_transforms:
 | 
			
		||||
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
 | 
			
		||||
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
 | 
			
		||||
        else:
 | 
			
		||||
            sample = self.torch_transforms(im)
 | 
			
		||||
        return {'img': sample, 'cls': j}
 | 
			
		||||
 | 
			
		||||
@ -25,15 +25,15 @@ class MixAndRectDataset:
 | 
			
		||||
        labels = deepcopy(self.dataset[index])
 | 
			
		||||
        for transform in self.dataset.transforms.tolist():
 | 
			
		||||
            # mosaic and mixup
 | 
			
		||||
            if hasattr(transform, "get_indexes"):
 | 
			
		||||
            if hasattr(transform, 'get_indexes'):
 | 
			
		||||
                indexes = transform.get_indexes(self.dataset)
 | 
			
		||||
                if not isinstance(indexes, collections.abc.Sequence):
 | 
			
		||||
                    indexes = [indexes]
 | 
			
		||||
                mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
 | 
			
		||||
                labels["mix_labels"] = mix_labels
 | 
			
		||||
                labels['mix_labels'] = mix_labels
 | 
			
		||||
            if self.dataset.rect and isinstance(transform, LetterBox):
 | 
			
		||||
                transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
 | 
			
		||||
            labels = transform(labels)
 | 
			
		||||
            if "mix_labels" in labels:
 | 
			
		||||
                labels.pop("mix_labels")
 | 
			
		||||
            if 'mix_labels' in labels:
 | 
			
		||||
                labels.pop('mix_labels')
 | 
			
		||||
        return labels
 | 
			
		||||
 | 
			
		||||
@ -55,4 +55,4 @@ download: |
 | 
			
		||||
              for r in x[images == im]:
 | 
			
		||||
                  w, h = r[6], r[7]  # image width, height
 | 
			
		||||
                  xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0]  # instance
 | 
			
		||||
                  f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n")  # write label
 | 
			
		||||
                  f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n")  # write label
 | 
			
		||||
 | 
			
		||||
@ -112,4 +112,4 @@ download: |
 | 
			
		||||
  urls = ['http://images.cocodataset.org/zips/train2017.zip',  # 19G, 118k images
 | 
			
		||||
          'http://images.cocodataset.org/zips/val2017.zip',  # 1G, 5k images
 | 
			
		||||
          'http://images.cocodataset.org/zips/test2017.zip']  # 7G, 41k images (optional)
 | 
			
		||||
  download(urls, dir=dir / 'images', threads=3)
 | 
			
		||||
  download(urls, dir=dir / 'images', threads=3)
 | 
			
		||||
 | 
			
		||||
@ -98,4 +98,4 @@ names:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Download script/URL (optional)
 | 
			
		||||
download: https://ultralytics.com/assets/coco128-seg.zip
 | 
			
		||||
download: https://ultralytics.com/assets/coco128-seg.zip
 | 
			
		||||
 | 
			
		||||
@ -98,4 +98,4 @@ names:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Download script/URL (optional)
 | 
			
		||||
download: https://ultralytics.com/assets/coco128.zip
 | 
			
		||||
download: https://ultralytics.com/assets/coco128.zip
 | 
			
		||||
 | 
			
		||||
@ -98,4 +98,4 @@ names:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Download script/URL (optional)
 | 
			
		||||
download: https://ultralytics.com/assets/coco8-seg.zip
 | 
			
		||||
download: https://ultralytics.com/assets/coco8-seg.zip
 | 
			
		||||
 | 
			
		||||
@ -98,4 +98,4 @@ names:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Download script/URL (optional)
 | 
			
		||||
download: https://ultralytics.com/assets/coco8.zip
 | 
			
		||||
download: https://ultralytics.com/assets/coco8.zip
 | 
			
		||||
 | 
			
		||||
@ -18,32 +18,32 @@ from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
 | 
			
		||||
from ultralytics.yolo.utils.downloads import download, safe_download
 | 
			
		||||
from ultralytics.yolo.utils.ops import segments2boxes
 | 
			
		||||
 | 
			
		||||
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
 | 
			
		||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
 | 
			
		||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
 | 
			
		||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
 | 
			
		||||
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 | 
			
		||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
 | 
			
		||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
 | 
			
		||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 | 
			
		||||
RANK = int(os.getenv('RANK', -1))
 | 
			
		||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
 | 
			
		||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders
 | 
			
		||||
IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
 | 
			
		||||
IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation
 | 
			
		||||
 | 
			
		||||
# Get orientation exif tag
 | 
			
		||||
for orientation in ExifTags.TAGS.keys():
 | 
			
		||||
    if ExifTags.TAGS[orientation] == "Orientation":
 | 
			
		||||
    if ExifTags.TAGS[orientation] == 'Orientation':
 | 
			
		||||
        break
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def img2label_paths(img_paths):
 | 
			
		||||
    # Define label paths as a function of image paths
 | 
			
		||||
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
 | 
			
		||||
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
 | 
			
		||||
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
 | 
			
		||||
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_hash(paths):
 | 
			
		||||
    # Returns a single hash value of a list of paths (files or dirs)
 | 
			
		||||
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
 | 
			
		||||
    h = hashlib.sha256(str(size).encode())  # hash sizes
 | 
			
		||||
    h.update("".join(paths).encode())  # hash paths
 | 
			
		||||
    h.update(''.join(paths).encode())  # hash paths
 | 
			
		||||
    return h.hexdigest()  # return hash
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -61,21 +61,21 @@ def verify_image_label(args):
 | 
			
		||||
    # Verify one image-label pair
 | 
			
		||||
    im_file, lb_file, prefix, keypoint, num_cls = args
 | 
			
		||||
    # number (missing, found, empty, corrupt), message, segments, keypoints
 | 
			
		||||
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
 | 
			
		||||
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
 | 
			
		||||
    try:
 | 
			
		||||
        # verify images
 | 
			
		||||
        im = Image.open(im_file)
 | 
			
		||||
        im.verify()  # PIL verify
 | 
			
		||||
        shape = exif_size(im)  # image size
 | 
			
		||||
        shape = (shape[1], shape[0])  # hw
 | 
			
		||||
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
 | 
			
		||||
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
 | 
			
		||||
        if im.format.lower() in ("jpg", "jpeg"):
 | 
			
		||||
            with open(im_file, "rb") as f:
 | 
			
		||||
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
 | 
			
		||||
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
 | 
			
		||||
        if im.format.lower() in ('jpg', 'jpeg'):
 | 
			
		||||
            with open(im_file, 'rb') as f:
 | 
			
		||||
                f.seek(-2, 2)
 | 
			
		||||
                if f.read() != b"\xff\xd9":  # corrupt JPEG
 | 
			
		||||
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
 | 
			
		||||
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
 | 
			
		||||
                if f.read() != b'\xff\xd9':  # corrupt JPEG
 | 
			
		||||
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
 | 
			
		||||
                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
 | 
			
		||||
 | 
			
		||||
        # verify labels
 | 
			
		||||
        if os.path.isfile(lb_file):
 | 
			
		||||
@ -90,31 +90,31 @@ def verify_image_label(args):
 | 
			
		||||
            nl = len(lb)
 | 
			
		||||
            if nl:
 | 
			
		||||
                if keypoint:
 | 
			
		||||
                    assert lb.shape[1] == 56, "labels require 56 columns each"
 | 
			
		||||
                    assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
 | 
			
		||||
                    assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
 | 
			
		||||
                    assert lb.shape[1] == 56, 'labels require 56 columns each'
 | 
			
		||||
                    assert (lb[:, 5::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
 | 
			
		||||
                    assert (lb[:, 6::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
 | 
			
		||||
                    kpts = np.zeros((lb.shape[0], 39))
 | 
			
		||||
                    for i in range(len(lb)):
 | 
			
		||||
                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3))  # remove occlusion param from GT
 | 
			
		||||
                        kpts[i] = np.hstack((lb[i, :5], kpt))
 | 
			
		||||
                    lb = kpts
 | 
			
		||||
                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
 | 
			
		||||
                    assert lb.shape[1] == 39, 'labels require 39 columns each after removing occlusion parameter'
 | 
			
		||||
                else:
 | 
			
		||||
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
 | 
			
		||||
                    assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
 | 
			
		||||
                    assert (lb[:, 1:] <= 1).all(), \
 | 
			
		||||
                        f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
 | 
			
		||||
                        f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
 | 
			
		||||
                # All labels
 | 
			
		||||
                max_cls = int(lb[:, 0].max())  # max label count
 | 
			
		||||
                assert max_cls <= num_cls, \
 | 
			
		||||
                    f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
 | 
			
		||||
                    f'Possible class labels are 0-{num_cls - 1}'
 | 
			
		||||
                assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
 | 
			
		||||
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
 | 
			
		||||
                _, i = np.unique(lb, axis=0, return_index=True)
 | 
			
		||||
                if len(i) < nl:  # duplicate row check
 | 
			
		||||
                    lb = lb[i]  # remove duplicates
 | 
			
		||||
                    if segments:
 | 
			
		||||
                        segments = [segments[x] for x in i]
 | 
			
		||||
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
 | 
			
		||||
                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
 | 
			
		||||
            else:
 | 
			
		||||
                ne = 1  # label empty
 | 
			
		||||
                lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
 | 
			
		||||
@ -127,7 +127,7 @@ def verify_image_label(args):
 | 
			
		||||
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        nc = 1
 | 
			
		||||
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
 | 
			
		||||
        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
 | 
			
		||||
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -248,8 +248,8 @@ def check_det_dataset(dataset, autodownload=True):
 | 
			
		||||
            else:  # python script
 | 
			
		||||
                r = exec(s, {'yaml': data})  # return None
 | 
			
		||||
            dt = f'({round(time.time() - t, 1)}s)'
 | 
			
		||||
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
 | 
			
		||||
            LOGGER.info(f"Dataset download {s}\n")
 | 
			
		||||
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
 | 
			
		||||
            LOGGER.info(f'Dataset download {s}\n')
 | 
			
		||||
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts
 | 
			
		||||
 | 
			
		||||
    return data  # dictionary
 | 
			
		||||
@ -284,9 +284,9 @@ def check_cls_dataset(dataset: str):
 | 
			
		||||
            download(url, dir=data_dir.parent)
 | 
			
		||||
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
 | 
			
		||||
        LOGGER.info(s)
 | 
			
		||||
    train_set = data_dir / "train"
 | 
			
		||||
    train_set = data_dir / 'train'
 | 
			
		||||
    test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
 | 
			
		||||
    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
 | 
			
		||||
    names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
 | 
			
		||||
    names = dict(enumerate(sorted(names)))
 | 
			
		||||
    return {"train": train_set, "val": test_set, "nc": nc, "names": names}
 | 
			
		||||
    return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
 | 
			
		||||
 | 
			
		||||
@ -144,7 +144,7 @@ class Exporter:
 | 
			
		||||
 | 
			
		||||
    @smart_inference_mode()
 | 
			
		||||
    def __call__(self, model=None):
 | 
			
		||||
        self.run_callbacks("on_export_start")
 | 
			
		||||
        self.run_callbacks('on_export_start')
 | 
			
		||||
        t = time.time()
 | 
			
		||||
        format = self.args.format.lower()  # to lowercase
 | 
			
		||||
        if format in {'tensorrt', 'trt'}:  # engine aliases
 | 
			
		||||
@ -207,7 +207,7 @@ class Exporter:
 | 
			
		||||
        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
 | 
			
		||||
        self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
 | 
			
		||||
        self.metadata = {
 | 
			
		||||
            'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}",
 | 
			
		||||
            'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
 | 
			
		||||
            'author': 'Ultralytics',
 | 
			
		||||
            'license': 'GPL-3.0 https://ultralytics.com/license',
 | 
			
		||||
            'version': __version__,
 | 
			
		||||
@ -215,7 +215,7 @@ class Exporter:
 | 
			
		||||
            'names': model.names}  # model metadata
 | 
			
		||||
 | 
			
		||||
        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
 | 
			
		||||
                    f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
 | 
			
		||||
                    f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
 | 
			
		||||
 | 
			
		||||
        # Exports
 | 
			
		||||
        f = [''] * len(fmts)  # exported filenames
 | 
			
		||||
@ -259,15 +259,15 @@ class Exporter:
 | 
			
		||||
            s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
 | 
			
		||||
                                  f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
 | 
			
		||||
            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
 | 
			
		||||
            data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
 | 
			
		||||
            data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
 | 
			
		||||
            LOGGER.info(
 | 
			
		||||
                f'\nExport complete ({time.time() - t:.1f}s)'
 | 
			
		||||
                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
 | 
			
		||||
                f"\nPredict:         yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
 | 
			
		||||
                f"\nValidate:        yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
 | 
			
		||||
                f"\nVisualize:       https://netron.app")
 | 
			
		||||
                f'\nPredict:         yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
 | 
			
		||||
                f'\nValidate:        yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
 | 
			
		||||
                f'\nVisualize:       https://netron.app')
 | 
			
		||||
 | 
			
		||||
        self.run_callbacks("on_export_end")
 | 
			
		||||
        self.run_callbacks('on_export_end')
 | 
			
		||||
        return f  # return list of exported files/dirs
 | 
			
		||||
 | 
			
		||||
    @try_export
 | 
			
		||||
@ -277,7 +277,7 @@ class Exporter:
 | 
			
		||||
        f = self.file.with_suffix('.torchscript')
 | 
			
		||||
 | 
			
		||||
        ts = torch.jit.trace(self.model, self.im, strict=False)
 | 
			
		||||
        d = {"shape": self.im.shape, "stride": int(max(self.model.stride)), "names": self.model.names}
 | 
			
		||||
        d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
 | 
			
		||||
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
 | 
			
		||||
        if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
 | 
			
		||||
            LOGGER.info(f'{prefix} optimizing for mobile...')
 | 
			
		||||
@ -354,7 +354,7 @@ class Exporter:
 | 
			
		||||
 | 
			
		||||
        ov_model = mo.convert_model(f_onnx,
 | 
			
		||||
                                    model_name=self.pretty_name,
 | 
			
		||||
                                    framework="onnx",
 | 
			
		||||
                                    framework='onnx',
 | 
			
		||||
                                    compress_to_fp16=self.args.half)  # export
 | 
			
		||||
        ov.serialize(ov_model, f_ov)  # save
 | 
			
		||||
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
 | 
			
		||||
@ -471,7 +471,7 @@ class Exporter:
 | 
			
		||||
        if self.args.dynamic:
 | 
			
		||||
            shape = self.im.shape
 | 
			
		||||
            if shape[0] <= 1:
 | 
			
		||||
                LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
 | 
			
		||||
                LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
 | 
			
		||||
            profile = builder.create_optimization_profile()
 | 
			
		||||
            for inp in inputs:
 | 
			
		||||
                profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
 | 
			
		||||
@ -509,8 +509,8 @@ class Exporter:
 | 
			
		||||
        except ImportError:
 | 
			
		||||
            check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
 | 
			
		||||
            import tensorflow as tf  # noqa
 | 
			
		||||
        check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
 | 
			
		||||
                           cmds="--extra-index-url https://pypi.ngc.nvidia.com")
 | 
			
		||||
        check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
 | 
			
		||||
                           cmds='--extra-index-url https://pypi.ngc.nvidia.com')
 | 
			
		||||
 | 
			
		||||
        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
 | 
			
		||||
        f = str(self.file).replace(self.file.suffix, '_saved_model')
 | 
			
		||||
@ -632,7 +632,7 @@ class Exporter:
 | 
			
		||||
            converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
 | 
			
		||||
 | 
			
		||||
        tflite_model = converter.convert()
 | 
			
		||||
        open(f, "wb").write(tflite_model)
 | 
			
		||||
        open(f, 'wb').write(tflite_model)
 | 
			
		||||
        return f, None
 | 
			
		||||
 | 
			
		||||
    @try_export
 | 
			
		||||
@ -656,7 +656,7 @@ class Exporter:
 | 
			
		||||
        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
 | 
			
		||||
        f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model
 | 
			
		||||
 | 
			
		||||
        cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
 | 
			
		||||
        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
 | 
			
		||||
        subprocess.run(cmd.split(), check=True)
 | 
			
		||||
        self._add_tflite_metadata(f)
 | 
			
		||||
        return f, None
 | 
			
		||||
@ -707,8 +707,8 @@ class Exporter:
 | 
			
		||||
 | 
			
		||||
        # Creates input info.
 | 
			
		||||
        input_meta = _metadata_fb.TensorMetadataT()
 | 
			
		||||
        input_meta.name = "image"
 | 
			
		||||
        input_meta.description = "Input image to be detected."
 | 
			
		||||
        input_meta.name = 'image'
 | 
			
		||||
        input_meta.description = 'Input image to be detected.'
 | 
			
		||||
        input_meta.content = _metadata_fb.ContentT()
 | 
			
		||||
        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
 | 
			
		||||
        input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
 | 
			
		||||
@ -716,8 +716,8 @@ class Exporter:
 | 
			
		||||
 | 
			
		||||
        # Creates output info.
 | 
			
		||||
        output_meta = _metadata_fb.TensorMetadataT()
 | 
			
		||||
        output_meta.name = "output"
 | 
			
		||||
        output_meta.description = "Coordinates of detected objects, class labels, and confidence score."
 | 
			
		||||
        output_meta.name = 'output'
 | 
			
		||||
        output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'
 | 
			
		||||
 | 
			
		||||
        # Label file
 | 
			
		||||
        tmp_file = Path('/tmp/meta.txt')
 | 
			
		||||
@ -868,8 +868,8 @@ class Exporter:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export(cfg=DEFAULT_CFG):
 | 
			
		||||
    cfg.model = cfg.model or "yolov8n.yaml"
 | 
			
		||||
    cfg.format = cfg.format or "torchscript"
 | 
			
		||||
    cfg.model = cfg.model or 'yolov8n.yaml'
 | 
			
		||||
    cfg.format = cfg.format or 'torchscript'
 | 
			
		||||
 | 
			
		||||
    # exporter = Exporter(cfg)
 | 
			
		||||
    #
 | 
			
		||||
@ -888,7 +888,7 @@ def export(cfg=DEFAULT_CFG):
 | 
			
		||||
    model.export(**vars(cfg))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    """
 | 
			
		||||
    CLI:
 | 
			
		||||
    yolo mode=export model=yolov8n.yaml format=onnx
 | 
			
		||||
 | 
			
		||||
@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode
 | 
			
		||||
 | 
			
		||||
# Map head to model, trainer, validator, and predictor classes
 | 
			
		||||
MODEL_MAP = {
 | 
			
		||||
    "classify": [
 | 
			
		||||
    'classify': [
 | 
			
		||||
        ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
 | 
			
		||||
        'yolo.TYPE.classify.ClassificationPredictor'],
 | 
			
		||||
    "detect": [
 | 
			
		||||
    'detect': [
 | 
			
		||||
        DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
 | 
			
		||||
        'yolo.TYPE.detect.DetectionPredictor'],
 | 
			
		||||
    "segment": [
 | 
			
		||||
    'segment': [
 | 
			
		||||
        SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
 | 
			
		||||
        'yolo.TYPE.segment.SegmentationPredictor']}
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,7 @@ class YOLO:
 | 
			
		||||
    A python interface which emulates a model-like behaviour by wrapping trainers.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, model='yolov8n.pt', type="v8") -> None:
 | 
			
		||||
    def __init__(self, model='yolov8n.pt', type='v8') -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initializes the YOLO object.
 | 
			
		||||
 | 
			
		||||
@ -94,7 +94,7 @@ class YOLO:
 | 
			
		||||
        suffix = Path(weights).suffix
 | 
			
		||||
        if suffix == '.pt':
 | 
			
		||||
            self.model, self.ckpt = attempt_load_one_weight(weights)
 | 
			
		||||
            self.task = self.model.args["task"]
 | 
			
		||||
            self.task = self.model.args['task']
 | 
			
		||||
            self.overrides = self.model.args
 | 
			
		||||
            self._reset_ckpt_args(self.overrides)
 | 
			
		||||
        else:
 | 
			
		||||
@ -111,7 +111,7 @@ class YOLO:
 | 
			
		||||
        """
 | 
			
		||||
        if not isinstance(self.model, nn.Module):
 | 
			
		||||
            raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
 | 
			
		||||
                            f"PyTorch models can be used to train, val, predict and export, i.e. "
 | 
			
		||||
                            f'PyTorch models can be used to train, val, predict and export, i.e. '
 | 
			
		||||
                            f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
 | 
			
		||||
                            f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
 | 
			
		||||
 | 
			
		||||
@ -155,11 +155,11 @@ class YOLO:
 | 
			
		||||
            (List[ultralytics.yolo.engine.results.Results]): The prediction results.
 | 
			
		||||
        """
 | 
			
		||||
        overrides = self.overrides.copy()
 | 
			
		||||
        overrides["conf"] = 0.25
 | 
			
		||||
        overrides['conf'] = 0.25
 | 
			
		||||
        overrides.update(kwargs)
 | 
			
		||||
        overrides["mode"] = kwargs.get("mode", "predict")
 | 
			
		||||
        assert overrides["mode"] in ['track', 'predict']
 | 
			
		||||
        overrides["save"] = kwargs.get("save", False)  # not save files by default
 | 
			
		||||
        overrides['mode'] = kwargs.get('mode', 'predict')
 | 
			
		||||
        assert overrides['mode'] in ['track', 'predict']
 | 
			
		||||
        overrides['save'] = kwargs.get('save', False)  # not save files by default
 | 
			
		||||
        if not self.predictor:
 | 
			
		||||
            self.predictor = self.PredictorClass(overrides=overrides)
 | 
			
		||||
            self.predictor.setup_model(model=self.model)
 | 
			
		||||
@ -173,7 +173,7 @@ class YOLO:
 | 
			
		||||
        from ultralytics.tracker.track import register_tracker
 | 
			
		||||
        register_tracker(self)
 | 
			
		||||
        # bytetrack-based method needs low confidence predictions as input
 | 
			
		||||
        conf = kwargs.get("conf") or 0.1
 | 
			
		||||
        conf = kwargs.get('conf') or 0.1
 | 
			
		||||
        kwargs['conf'] = conf
 | 
			
		||||
        kwargs['mode'] = 'track'
 | 
			
		||||
        return self.predict(source=source, stream=stream, **kwargs)
 | 
			
		||||
@ -188,9 +188,9 @@ class YOLO:
 | 
			
		||||
            **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
 | 
			
		||||
        """
 | 
			
		||||
        overrides = self.overrides.copy()
 | 
			
		||||
        overrides["rect"] = True  # rect batches as default
 | 
			
		||||
        overrides['rect'] = True  # rect batches as default
 | 
			
		||||
        overrides.update(kwargs)
 | 
			
		||||
        overrides["mode"] = "val"
 | 
			
		||||
        overrides['mode'] = 'val'
 | 
			
		||||
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
 | 
			
		||||
        args.data = data or args.data
 | 
			
		||||
        args.task = self.task
 | 
			
		||||
@ -234,18 +234,18 @@ class YOLO:
 | 
			
		||||
        self._check_is_pytorch_model()
 | 
			
		||||
        overrides = self.overrides.copy()
 | 
			
		||||
        overrides.update(kwargs)
 | 
			
		||||
        if kwargs.get("cfg"):
 | 
			
		||||
        if kwargs.get('cfg'):
 | 
			
		||||
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
 | 
			
		||||
            overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
 | 
			
		||||
        overrides["task"] = self.task
 | 
			
		||||
        overrides["mode"] = "train"
 | 
			
		||||
        if not overrides.get("data"):
 | 
			
		||||
            overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
 | 
			
		||||
        overrides['task'] = self.task
 | 
			
		||||
        overrides['mode'] = 'train'
 | 
			
		||||
        if not overrides.get('data'):
 | 
			
		||||
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
 | 
			
		||||
        if overrides.get("resume"):
 | 
			
		||||
            overrides["resume"] = self.ckpt_path
 | 
			
		||||
        if overrides.get('resume'):
 | 
			
		||||
            overrides['resume'] = self.ckpt_path
 | 
			
		||||
 | 
			
		||||
        self.trainer = self.TrainerClass(overrides=overrides)
 | 
			
		||||
        if not overrides.get("resume"):  # manually set model only if not resuming
 | 
			
		||||
        if not overrides.get('resume'):  # manually set model only if not resuming
 | 
			
		||||
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
 | 
			
		||||
            self.model = self.trainer.model
 | 
			
		||||
        self.trainer.train()
 | 
			
		||||
@ -267,9 +267,9 @@ class YOLO:
 | 
			
		||||
 | 
			
		||||
    def _assign_ops_from_task(self):
 | 
			
		||||
        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
 | 
			
		||||
        trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
 | 
			
		||||
        validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
 | 
			
		||||
        predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
 | 
			
		||||
        trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
 | 
			
		||||
        validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
 | 
			
		||||
        predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
 | 
			
		||||
        return model_class, trainer_class, validator_class, predictor_class
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
@ -292,7 +292,7 @@ class YOLO:
 | 
			
		||||
        Returns metrics if computed
 | 
			
		||||
        """
 | 
			
		||||
        if not self.metrics_data:
 | 
			
		||||
            LOGGER.info("No metrics data found! Run training or validation operation first.")
 | 
			
		||||
            LOGGER.info('No metrics data found! Run training or validation operation first.')
 | 
			
		||||
 | 
			
		||||
        return self.metrics_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -72,7 +72,7 @@ class BasePredictor:
 | 
			
		||||
        """
 | 
			
		||||
        self.args = get_cfg(cfg, overrides)
 | 
			
		||||
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
 | 
			
		||||
        name = self.args.name or f"{self.args.mode}"
 | 
			
		||||
        name = self.args.name or f'{self.args.mode}'
 | 
			
		||||
        self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
 | 
			
		||||
        if self.args.conf is None:
 | 
			
		||||
            self.args.conf = 0.25  # default conf=0.25
 | 
			
		||||
@ -97,10 +97,10 @@ class BasePredictor:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def get_annotator(self, img):
 | 
			
		||||
        raise NotImplementedError("get_annotator function needs to be implemented")
 | 
			
		||||
        raise NotImplementedError('get_annotator function needs to be implemented')
 | 
			
		||||
 | 
			
		||||
    def write_results(self, results, batch, print_string):
 | 
			
		||||
        raise NotImplementedError("print_results function needs to be implemented")
 | 
			
		||||
        raise NotImplementedError('print_results function needs to be implemented')
 | 
			
		||||
 | 
			
		||||
    def postprocess(self, preds, img, orig_img):
 | 
			
		||||
        return preds
 | 
			
		||||
@ -135,7 +135,7 @@ class BasePredictor:
 | 
			
		||||
 | 
			
		||||
    def stream_inference(self, source=None, model=None):
 | 
			
		||||
        if self.args.verbose:
 | 
			
		||||
            LOGGER.info("")
 | 
			
		||||
            LOGGER.info('')
 | 
			
		||||
 | 
			
		||||
        # setup model
 | 
			
		||||
        if not self.model:
 | 
			
		||||
@ -152,9 +152,9 @@ class BasePredictor:
 | 
			
		||||
            self.done_warmup = True
 | 
			
		||||
 | 
			
		||||
        self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
 | 
			
		||||
        self.run_callbacks("on_predict_start")
 | 
			
		||||
        self.run_callbacks('on_predict_start')
 | 
			
		||||
        for batch in self.dataset:
 | 
			
		||||
            self.run_callbacks("on_predict_batch_start")
 | 
			
		||||
            self.run_callbacks('on_predict_batch_start')
 | 
			
		||||
            self.batch = batch
 | 
			
		||||
            path, im, im0s, vid_cap, s = batch
 | 
			
		||||
            visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
 | 
			
		||||
@ -170,7 +170,7 @@ class BasePredictor:
 | 
			
		||||
            # postprocess
 | 
			
		||||
            with self.dt[2]:
 | 
			
		||||
                self.results = self.postprocess(preds, im, im0s)
 | 
			
		||||
            self.run_callbacks("on_predict_postprocess_end")
 | 
			
		||||
            self.run_callbacks('on_predict_postprocess_end')
 | 
			
		||||
 | 
			
		||||
            # visualize, save, write results
 | 
			
		||||
            for i in range(len(im)):
 | 
			
		||||
@ -186,7 +186,7 @@ class BasePredictor:
 | 
			
		||||
 | 
			
		||||
                if self.args.save:
 | 
			
		||||
                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))
 | 
			
		||||
            self.run_callbacks("on_predict_batch_end")
 | 
			
		||||
            self.run_callbacks('on_predict_batch_end')
 | 
			
		||||
            yield from self.results
 | 
			
		||||
 | 
			
		||||
            # Print time (inference-only)
 | 
			
		||||
@ -207,7 +207,7 @@ class BasePredictor:
 | 
			
		||||
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
 | 
			
		||||
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
 | 
			
		||||
 | 
			
		||||
        self.run_callbacks("on_predict_end")
 | 
			
		||||
        self.run_callbacks('on_predict_end')
 | 
			
		||||
 | 
			
		||||
    def setup_model(self, model):
 | 
			
		||||
        device = select_device(self.args.device)
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,7 @@ class Results:
 | 
			
		||||
        self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
 | 
			
		||||
        self.probs = probs if probs is not None else None
 | 
			
		||||
        self.names = names
 | 
			
		||||
        self.comp = ["boxes", "masks", "probs"]
 | 
			
		||||
        self.comp = ['boxes', 'masks', 'probs']
 | 
			
		||||
 | 
			
		||||
    def pandas(self):
 | 
			
		||||
        pass
 | 
			
		||||
@ -97,7 +97,7 @@ class Results:
 | 
			
		||||
            return len(getattr(self, item))
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        str_out = ""
 | 
			
		||||
        str_out = ''
 | 
			
		||||
        for item in self.comp:
 | 
			
		||||
            if getattr(self, item) is None:
 | 
			
		||||
                continue
 | 
			
		||||
@ -105,7 +105,7 @@ class Results:
 | 
			
		||||
        return str_out
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        str_out = ""
 | 
			
		||||
        str_out = ''
 | 
			
		||||
        for item in self.comp:
 | 
			
		||||
            if getattr(self, item) is None:
 | 
			
		||||
                continue
 | 
			
		||||
@ -187,7 +187,7 @@ class Boxes:
 | 
			
		||||
        if boxes.ndim == 1:
 | 
			
		||||
            boxes = boxes[None, :]
 | 
			
		||||
        n = boxes.shape[-1]
 | 
			
		||||
        assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}"  # xyxy, (track_id), conf, cls
 | 
			
		||||
        assert n in {6, 7}, f'expected `n` in [6, 7], but got {n}'  # xyxy, (track_id), conf, cls
 | 
			
		||||
        # TODO
 | 
			
		||||
        self.is_track = n == 7
 | 
			
		||||
        self.boxes = boxes
 | 
			
		||||
@ -268,8 +268,8 @@ class Boxes:
 | 
			
		||||
        return self.boxes.__str__()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.boxes)}\n" +
 | 
			
		||||
                f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}")
 | 
			
		||||
        return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.boxes)}\n' +
 | 
			
		||||
                f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        boxes = self.boxes[idx]
 | 
			
		||||
@ -353,8 +353,8 @@ class Masks:
 | 
			
		||||
        return self.masks.__str__()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
 | 
			
		||||
                f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}\n + {self.masks.__repr__()}")
 | 
			
		||||
        return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
 | 
			
		||||
                f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        masks = self.masks[idx]
 | 
			
		||||
@ -374,19 +374,19 @@ class Masks:
 | 
			
		||||
            """)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    # test examples
 | 
			
		||||
    results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
 | 
			
		||||
    results = results.cuda()
 | 
			
		||||
    print("--cuda--pass--")
 | 
			
		||||
    print('--cuda--pass--')
 | 
			
		||||
    results = results.cpu()
 | 
			
		||||
    print("--cpu--pass--")
 | 
			
		||||
    results = results.to("cuda:0")
 | 
			
		||||
    print("--to-cuda--pass--")
 | 
			
		||||
    results = results.to("cpu")
 | 
			
		||||
    print("--to-cpu--pass--")
 | 
			
		||||
    print('--cpu--pass--')
 | 
			
		||||
    results = results.to('cuda:0')
 | 
			
		||||
    print('--to-cuda--pass--')
 | 
			
		||||
    results = results.to('cpu')
 | 
			
		||||
    print('--to-cpu--pass--')
 | 
			
		||||
    results = results.numpy()
 | 
			
		||||
    print("--numpy--pass--")
 | 
			
		||||
    print('--numpy--pass--')
 | 
			
		||||
    # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
 | 
			
		||||
    # box = box.cuda()
 | 
			
		||||
    # box = box.cpu()
 | 
			
		||||
 | 
			
		||||
@ -90,7 +90,7 @@ class BaseTrainer:
 | 
			
		||||
 | 
			
		||||
        # Dirs
 | 
			
		||||
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
 | 
			
		||||
        name = self.args.name or f"{self.args.mode}"
 | 
			
		||||
        name = self.args.name or f'{self.args.mode}'
 | 
			
		||||
        if hasattr(self.args, 'save_dir'):
 | 
			
		||||
            self.save_dir = Path(self.args.save_dir)
 | 
			
		||||
        else:
 | 
			
		||||
@ -121,7 +121,7 @@ class BaseTrainer:
 | 
			
		||||
        try:
 | 
			
		||||
            if self.args.task == 'classify':
 | 
			
		||||
                self.data = check_cls_dataset(self.args.data)
 | 
			
		||||
            elif self.args.data.endswith(".yaml") or self.args.task in ('detect', 'segment'):
 | 
			
		||||
            elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
 | 
			
		||||
                self.data = check_det_dataset(self.args.data)
 | 
			
		||||
                if 'yaml_file' in self.data:
 | 
			
		||||
                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
 | 
			
		||||
@ -175,7 +175,7 @@ class BaseTrainer:
 | 
			
		||||
            world_size = 0
 | 
			
		||||
 | 
			
		||||
        # Run subprocess if DDP training, else train normally
 | 
			
		||||
        if world_size > 1 and "LOCAL_RANK" not in os.environ:
 | 
			
		||||
        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
 | 
			
		||||
            cmd, file = generate_ddp_command(world_size, self)  # security vulnerability in Snyk scans
 | 
			
		||||
            try:
 | 
			
		||||
                subprocess.run(cmd, check=True)
 | 
			
		||||
@ -191,15 +191,15 @@ class BaseTrainer:
 | 
			
		||||
        # os.environ['MASTER_PORT'] = '9020'
 | 
			
		||||
        torch.cuda.set_device(rank)
 | 
			
		||||
        self.device = torch.device('cuda', rank)
 | 
			
		||||
        self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
 | 
			
		||||
        dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
 | 
			
		||||
        self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
 | 
			
		||||
        dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
 | 
			
		||||
 | 
			
		||||
    def _setup_train(self, rank, world_size):
 | 
			
		||||
        """
 | 
			
		||||
        Builds dataloaders and optimizer on correct rank process.
 | 
			
		||||
        """
 | 
			
		||||
        # model
 | 
			
		||||
        self.run_callbacks("on_pretrain_routine_start")
 | 
			
		||||
        self.run_callbacks('on_pretrain_routine_start')
 | 
			
		||||
        ckpt = self.setup_model()
 | 
			
		||||
        self.model = self.model.to(self.device)
 | 
			
		||||
        self.set_model_attributes()
 | 
			
		||||
@ -234,16 +234,16 @@ class BaseTrainer:
 | 
			
		||||
 | 
			
		||||
        # dataloaders
 | 
			
		||||
        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
 | 
			
		||||
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
 | 
			
		||||
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
 | 
			
		||||
        if rank in {0, -1}:
 | 
			
		||||
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
 | 
			
		||||
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
 | 
			
		||||
            self.validator = self.get_validator()
 | 
			
		||||
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
 | 
			
		||||
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
 | 
			
		||||
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
 | 
			
		||||
            self.ema = ModelEMA(self.model)
 | 
			
		||||
        self.resume_training(ckpt)
 | 
			
		||||
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
 | 
			
		||||
        self.run_callbacks("on_pretrain_routine_end")
 | 
			
		||||
        self.run_callbacks('on_pretrain_routine_end')
 | 
			
		||||
 | 
			
		||||
    def _do_train(self, rank=-1, world_size=1):
 | 
			
		||||
        if world_size > 1:
 | 
			
		||||
@ -257,24 +257,24 @@ class BaseTrainer:
 | 
			
		||||
        nb = len(self.train_loader)  # number of batches
 | 
			
		||||
        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
 | 
			
		||||
        last_opt_step = -1
 | 
			
		||||
        self.run_callbacks("on_train_start")
 | 
			
		||||
        self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
 | 
			
		||||
        self.run_callbacks('on_train_start')
 | 
			
		||||
        self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
 | 
			
		||||
                 f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
 | 
			
		||||
                 f"Logging results to {colorstr('bold', self.save_dir)}\n"
 | 
			
		||||
                 f"Starting training for {self.epochs} epochs...")
 | 
			
		||||
                 f'Starting training for {self.epochs} epochs...')
 | 
			
		||||
        if self.args.close_mosaic:
 | 
			
		||||
            base_idx = (self.epochs - self.args.close_mosaic) * nb
 | 
			
		||||
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
 | 
			
		||||
        for epoch in range(self.start_epoch, self.epochs):
 | 
			
		||||
            self.epoch = epoch
 | 
			
		||||
            self.run_callbacks("on_train_epoch_start")
 | 
			
		||||
            self.run_callbacks('on_train_epoch_start')
 | 
			
		||||
            self.model.train()
 | 
			
		||||
            if rank != -1:
 | 
			
		||||
                self.train_loader.sampler.set_epoch(epoch)
 | 
			
		||||
            pbar = enumerate(self.train_loader)
 | 
			
		||||
            # Update dataloader attributes (optional)
 | 
			
		||||
            if epoch == (self.epochs - self.args.close_mosaic):
 | 
			
		||||
                self.console.info("Closing dataloader mosaic")
 | 
			
		||||
                self.console.info('Closing dataloader mosaic')
 | 
			
		||||
                if hasattr(self.train_loader.dataset, 'mosaic'):
 | 
			
		||||
                    self.train_loader.dataset.mosaic = False
 | 
			
		||||
                if hasattr(self.train_loader.dataset, 'close_mosaic'):
 | 
			
		||||
@ -286,7 +286,7 @@ class BaseTrainer:
 | 
			
		||||
            self.tloss = None
 | 
			
		||||
            self.optimizer.zero_grad()
 | 
			
		||||
            for i, batch in pbar:
 | 
			
		||||
                self.run_callbacks("on_train_batch_start")
 | 
			
		||||
                self.run_callbacks('on_train_batch_start')
 | 
			
		||||
                # Warmup
 | 
			
		||||
                ni = i + nb * epoch
 | 
			
		||||
                if ni <= nw:
 | 
			
		||||
@ -302,7 +302,7 @@ class BaseTrainer:
 | 
			
		||||
                # Forward
 | 
			
		||||
                with torch.cuda.amp.autocast(self.amp):
 | 
			
		||||
                    batch = self.preprocess_batch(batch)
 | 
			
		||||
                    preds = self.model(batch["img"])
 | 
			
		||||
                    preds = self.model(batch['img'])
 | 
			
		||||
                    self.loss, self.loss_items = self.criterion(preds, batch)
 | 
			
		||||
                    if rank != -1:
 | 
			
		||||
                        self.loss *= world_size
 | 
			
		||||
@ -324,17 +324,17 @@ class BaseTrainer:
 | 
			
		||||
                if rank in {-1, 0}:
 | 
			
		||||
                    pbar.set_description(
 | 
			
		||||
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
 | 
			
		||||
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
 | 
			
		||||
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
 | 
			
		||||
                    self.run_callbacks('on_batch_end')
 | 
			
		||||
                    if self.args.plots and ni in self.plot_idx:
 | 
			
		||||
                        self.plot_training_samples(batch, ni)
 | 
			
		||||
 | 
			
		||||
                self.run_callbacks("on_train_batch_end")
 | 
			
		||||
                self.run_callbacks('on_train_batch_end')
 | 
			
		||||
 | 
			
		||||
            self.lr = {f"lr/pg{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
 | 
			
		||||
            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
 | 
			
		||||
 | 
			
		||||
            self.scheduler.step()
 | 
			
		||||
            self.run_callbacks("on_train_epoch_end")
 | 
			
		||||
            self.run_callbacks('on_train_epoch_end')
 | 
			
		||||
 | 
			
		||||
            if rank in {-1, 0}:
 | 
			
		||||
 | 
			
		||||
@ -355,7 +355,7 @@ class BaseTrainer:
 | 
			
		||||
            tnow = time.time()
 | 
			
		||||
            self.epoch_time = tnow - self.epoch_time_start
 | 
			
		||||
            self.epoch_time_start = tnow
 | 
			
		||||
            self.run_callbacks("on_fit_epoch_end")
 | 
			
		||||
            self.run_callbacks('on_fit_epoch_end')
 | 
			
		||||
 | 
			
		||||
            # Early Stopping
 | 
			
		||||
            if RANK != -1:  # if DDP training
 | 
			
		||||
@ -402,7 +402,7 @@ class BaseTrainer:
 | 
			
		||||
        """
 | 
			
		||||
        Get train, val path from data dict if it exists. Returns None if data format is not recognized.
 | 
			
		||||
        """
 | 
			
		||||
        return data["train"], data.get("val") or data.get("test")
 | 
			
		||||
        return data['train'], data.get('val') or data.get('test')
 | 
			
		||||
 | 
			
		||||
    def setup_model(self):
 | 
			
		||||
        """
 | 
			
		||||
@ -413,9 +413,9 @@ class BaseTrainer:
 | 
			
		||||
 | 
			
		||||
        model, weights = self.model, None
 | 
			
		||||
        ckpt = None
 | 
			
		||||
        if str(model).endswith(".pt"):
 | 
			
		||||
        if str(model).endswith('.pt'):
 | 
			
		||||
            weights, ckpt = attempt_load_one_weight(model)
 | 
			
		||||
            cfg = ckpt["model"].yaml
 | 
			
		||||
            cfg = ckpt['model'].yaml
 | 
			
		||||
        else:
 | 
			
		||||
            cfg = model
 | 
			
		||||
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
 | 
			
		||||
@ -441,7 +441,7 @@ class BaseTrainer:
 | 
			
		||||
        Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
 | 
			
		||||
        """
 | 
			
		||||
        metrics = self.validator(self)
 | 
			
		||||
        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
 | 
			
		||||
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
 | 
			
		||||
        if not self.best_fitness or self.best_fitness < fitness:
 | 
			
		||||
            self.best_fitness = fitness
 | 
			
		||||
        return metrics, fitness
 | 
			
		||||
@ -462,38 +462,38 @@ class BaseTrainer:
 | 
			
		||||
        raise NotImplementedError("This task trainer doesn't support loading cfg files")
 | 
			
		||||
 | 
			
		||||
    def get_validator(self):
 | 
			
		||||
        raise NotImplementedError("get_validator function not implemented in trainer")
 | 
			
		||||
        raise NotImplementedError('get_validator function not implemented in trainer')
 | 
			
		||||
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
 | 
			
		||||
        """
 | 
			
		||||
        Returns dataloader derived from torch.data.Dataloader.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError("get_dataloader function not implemented in trainer")
 | 
			
		||||
        raise NotImplementedError('get_dataloader function not implemented in trainer')
 | 
			
		||||
 | 
			
		||||
    def criterion(self, preds, batch):
 | 
			
		||||
        """
 | 
			
		||||
        Returns loss and individual loss items as Tensor.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError("criterion function not implemented in trainer")
 | 
			
		||||
        raise NotImplementedError('criterion function not implemented in trainer')
 | 
			
		||||
 | 
			
		||||
    def label_loss_items(self, loss_items=None, prefix="train"):
 | 
			
		||||
    def label_loss_items(self, loss_items=None, prefix='train'):
 | 
			
		||||
        """
 | 
			
		||||
        Returns a loss dict with labelled training loss items tensor
 | 
			
		||||
        """
 | 
			
		||||
        # Not needed for classification but necessary for segmentation & detection
 | 
			
		||||
        return {"loss": loss_items} if loss_items is not None else ["loss"]
 | 
			
		||||
        return {'loss': loss_items} if loss_items is not None else ['loss']
 | 
			
		||||
 | 
			
		||||
    def set_model_attributes(self):
 | 
			
		||||
        """
 | 
			
		||||
        To set or update model parameters before training.
 | 
			
		||||
        """
 | 
			
		||||
        self.model.names = self.data["names"]
 | 
			
		||||
        self.model.names = self.data['names']
 | 
			
		||||
 | 
			
		||||
    def build_targets(self, preds, targets):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def progress_string(self):
 | 
			
		||||
        return ""
 | 
			
		||||
        return ''
 | 
			
		||||
 | 
			
		||||
    # TODO: may need to put these following functions into callback
 | 
			
		||||
    def plot_training_samples(self, batch, ni):
 | 
			
		||||
@ -529,7 +529,7 @@ class BaseTrainer:
 | 
			
		||||
                self.args = get_cfg(attempt_load_weights(last).args)
 | 
			
		||||
                self.args.model, resume = str(last), True  # reinstate
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
 | 
			
		||||
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
 | 
			
		||||
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
 | 
			
		||||
        self.resume = resume
 | 
			
		||||
 | 
			
		||||
@ -557,7 +557,7 @@ class BaseTrainer:
 | 
			
		||||
        self.best_fitness = best_fitness
 | 
			
		||||
        self.start_epoch = start_epoch
 | 
			
		||||
        if start_epoch > (self.epochs - self.args.close_mosaic):
 | 
			
		||||
            self.console.info("Closing dataloader mosaic")
 | 
			
		||||
            self.console.info('Closing dataloader mosaic')
 | 
			
		||||
            if hasattr(self.train_loader.dataset, 'mosaic'):
 | 
			
		||||
                self.train_loader.dataset.mosaic = False
 | 
			
		||||
            if hasattr(self.train_loader.dataset, 'close_mosaic'):
 | 
			
		||||
@ -602,5 +602,5 @@ class BaseTrainer:
 | 
			
		||||
        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
 | 
			
		||||
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
 | 
			
		||||
        LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
 | 
			
		||||
                    f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
 | 
			
		||||
                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
 | 
			
		||||
        return optimizer
 | 
			
		||||
 | 
			
		||||
@ -62,7 +62,7 @@ class BaseValidator:
 | 
			
		||||
        self.jdict = None
 | 
			
		||||
 | 
			
		||||
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
 | 
			
		||||
        name = self.args.name or f"{self.args.mode}"
 | 
			
		||||
        name = self.args.name or f'{self.args.mode}'
 | 
			
		||||
        self.save_dir = save_dir or increment_path(Path(project) / name,
 | 
			
		||||
                                                   exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
 | 
			
		||||
        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
 | 
			
		||||
@ -92,7 +92,7 @@ class BaseValidator:
 | 
			
		||||
        else:
 | 
			
		||||
            callbacks.add_integration_callbacks(self)
 | 
			
		||||
            self.run_callbacks('on_val_start')
 | 
			
		||||
            assert model is not None, "Either trainer or model is needed for validation"
 | 
			
		||||
            assert model is not None, 'Either trainer or model is needed for validation'
 | 
			
		||||
            self.device = select_device(self.args.device, self.args.batch)
 | 
			
		||||
            self.args.half &= self.device.type != 'cpu'
 | 
			
		||||
            model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
 | 
			
		||||
@ -108,7 +108,7 @@ class BaseValidator:
 | 
			
		||||
                    self.logger.info(
 | 
			
		||||
                        f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 | 
			
		||||
 | 
			
		||||
            if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
 | 
			
		||||
            if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
 | 
			
		||||
                self.data = check_det_dataset(self.args.data)
 | 
			
		||||
            elif self.args.task == 'classify':
 | 
			
		||||
                self.data = check_cls_dataset(self.args.data)
 | 
			
		||||
@ -142,7 +142,7 @@ class BaseValidator:
 | 
			
		||||
 | 
			
		||||
            # inference
 | 
			
		||||
            with dt[1]:
 | 
			
		||||
                preds = model(batch["img"])
 | 
			
		||||
                preds = model(batch['img'])
 | 
			
		||||
 | 
			
		||||
            # loss
 | 
			
		||||
            with dt[2]:
 | 
			
		||||
@ -166,14 +166,14 @@ class BaseValidator:
 | 
			
		||||
        self.run_callbacks('on_val_end')
 | 
			
		||||
        if self.training:
 | 
			
		||||
            model.float()
 | 
			
		||||
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
 | 
			
		||||
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
 | 
			
		||||
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
 | 
			
		||||
        else:
 | 
			
		||||
            self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
 | 
			
		||||
                             self.speed)
 | 
			
		||||
            if self.args.save_json and self.jdict:
 | 
			
		||||
                with open(str(self.save_dir / "predictions.json"), 'w') as f:
 | 
			
		||||
                    self.logger.info(f"Saving {f.name}...")
 | 
			
		||||
                with open(str(self.save_dir / 'predictions.json'), 'w') as f:
 | 
			
		||||
                    self.logger.info(f'Saving {f.name}...')
 | 
			
		||||
                    json.dump(self.jdict, f)  # flatten and save
 | 
			
		||||
                stats = self.eval_json(stats)  # update stats
 | 
			
		||||
            return stats
 | 
			
		||||
@ -183,7 +183,7 @@ class BaseValidator:
 | 
			
		||||
            callback(self)
 | 
			
		||||
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size):
 | 
			
		||||
        raise NotImplementedError("get_dataloader function not implemented for this validator")
 | 
			
		||||
        raise NotImplementedError('get_dataloader function not implemented for this validator')
 | 
			
		||||
 | 
			
		||||
    def preprocess(self, batch):
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,7 @@ from ultralytics import __version__
 | 
			
		||||
# Constants
 | 
			
		||||
FILE = Path(__file__).resolve()
 | 
			
		||||
ROOT = FILE.parents[2]  # YOLO
 | 
			
		||||
DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
 | 
			
		||||
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
 | 
			
		||||
RANK = int(os.getenv('RANK', -1))
 | 
			
		||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
 | 
			
		||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
 | 
			
		||||
@ -111,7 +111,7 @@ class IterableSimpleNamespace(SimpleNamespace):
 | 
			
		||||
        return iter(vars(self).items())
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
 | 
			
		||||
        return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, attr):
 | 
			
		||||
        name = self.__class__.__name__
 | 
			
		||||
@ -288,7 +288,7 @@ def is_pytest_running():
 | 
			
		||||
        (bool): True if pytest is running, False otherwise.
 | 
			
		||||
    """
 | 
			
		||||
    with contextlib.suppress(Exception):
 | 
			
		||||
        return "pytest" in sys.modules
 | 
			
		||||
        return 'pytest' in sys.modules
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -336,7 +336,7 @@ def get_git_origin_url():
 | 
			
		||||
    """
 | 
			
		||||
    if is_git_dir():
 | 
			
		||||
        with contextlib.suppress(subprocess.CalledProcessError):
 | 
			
		||||
            origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
 | 
			
		||||
            origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
 | 
			
		||||
            return origin.decode().strip()
 | 
			
		||||
    return None  # if not git dir or on error
 | 
			
		||||
 | 
			
		||||
@ -350,7 +350,7 @@ def get_git_branch():
 | 
			
		||||
    """
 | 
			
		||||
    if is_git_dir():
 | 
			
		||||
        with contextlib.suppress(subprocess.CalledProcessError):
 | 
			
		||||
            origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
 | 
			
		||||
            origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
 | 
			
		||||
            return origin.decode().strip()
 | 
			
		||||
    return None  # if not git dir or on error
 | 
			
		||||
 | 
			
		||||
@ -365,9 +365,9 @@ def get_latest_pypi_version(package_name='ultralytics'):
 | 
			
		||||
    Returns:
 | 
			
		||||
        str: The latest version of the package.
 | 
			
		||||
    """
 | 
			
		||||
    response = requests.get(f"https://pypi.org/pypi/{package_name}/json")
 | 
			
		||||
    response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
 | 
			
		||||
    if response.status_code == 200:
 | 
			
		||||
        return response.json()["info"]["version"]
 | 
			
		||||
        return response.json()['info']['version']
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -424,28 +424,28 @@ def emojis(string=''):
 | 
			
		||||
 | 
			
		||||
def colorstr(*input):
 | 
			
		||||
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')
 | 
			
		||||
    *args, string = input if len(input) > 1 else ("blue", "bold", input[0])  # color arguments, string
 | 
			
		||||
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
 | 
			
		||||
    colors = {
 | 
			
		||||
        "black": "\033[30m",  # basic colors
 | 
			
		||||
        "red": "\033[31m",
 | 
			
		||||
        "green": "\033[32m",
 | 
			
		||||
        "yellow": "\033[33m",
 | 
			
		||||
        "blue": "\033[34m",
 | 
			
		||||
        "magenta": "\033[35m",
 | 
			
		||||
        "cyan": "\033[36m",
 | 
			
		||||
        "white": "\033[37m",
 | 
			
		||||
        "bright_black": "\033[90m",  # bright colors
 | 
			
		||||
        "bright_red": "\033[91m",
 | 
			
		||||
        "bright_green": "\033[92m",
 | 
			
		||||
        "bright_yellow": "\033[93m",
 | 
			
		||||
        "bright_blue": "\033[94m",
 | 
			
		||||
        "bright_magenta": "\033[95m",
 | 
			
		||||
        "bright_cyan": "\033[96m",
 | 
			
		||||
        "bright_white": "\033[97m",
 | 
			
		||||
        "end": "\033[0m",  # misc
 | 
			
		||||
        "bold": "\033[1m",
 | 
			
		||||
        "underline": "\033[4m"}
 | 
			
		||||
    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
 | 
			
		||||
        'black': '\033[30m',  # basic colors
 | 
			
		||||
        'red': '\033[31m',
 | 
			
		||||
        'green': '\033[32m',
 | 
			
		||||
        'yellow': '\033[33m',
 | 
			
		||||
        'blue': '\033[34m',
 | 
			
		||||
        'magenta': '\033[35m',
 | 
			
		||||
        'cyan': '\033[36m',
 | 
			
		||||
        'white': '\033[37m',
 | 
			
		||||
        'bright_black': '\033[90m',  # bright colors
 | 
			
		||||
        'bright_red': '\033[91m',
 | 
			
		||||
        'bright_green': '\033[92m',
 | 
			
		||||
        'bright_yellow': '\033[93m',
 | 
			
		||||
        'bright_blue': '\033[94m',
 | 
			
		||||
        'bright_magenta': '\033[95m',
 | 
			
		||||
        'bright_cyan': '\033[96m',
 | 
			
		||||
        'bright_white': '\033[97m',
 | 
			
		||||
        'end': '\033[0m',  # misc
 | 
			
		||||
        'bold': '\033[1m',
 | 
			
		||||
        'underline': '\033[4m'}
 | 
			
		||||
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def remove_ansi_codes(string):
 | 
			
		||||
@ -466,21 +466,21 @@ def set_logging(name=LOGGING_NAME, verbose=True):
 | 
			
		||||
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
 | 
			
		||||
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
 | 
			
		||||
    logging.config.dictConfig({
 | 
			
		||||
        "version": 1,
 | 
			
		||||
        "disable_existing_loggers": False,
 | 
			
		||||
        "formatters": {
 | 
			
		||||
        'version': 1,
 | 
			
		||||
        'disable_existing_loggers': False,
 | 
			
		||||
        'formatters': {
 | 
			
		||||
            name: {
 | 
			
		||||
                "format": "%(message)s"}},
 | 
			
		||||
        "handlers": {
 | 
			
		||||
                'format': '%(message)s'}},
 | 
			
		||||
        'handlers': {
 | 
			
		||||
            name: {
 | 
			
		||||
                "class": "logging.StreamHandler",
 | 
			
		||||
                "formatter": name,
 | 
			
		||||
                "level": level}},
 | 
			
		||||
        "loggers": {
 | 
			
		||||
                'class': 'logging.StreamHandler',
 | 
			
		||||
                'formatter': name,
 | 
			
		||||
                'level': level}},
 | 
			
		||||
        'loggers': {
 | 
			
		||||
            name: {
 | 
			
		||||
                "level": level,
 | 
			
		||||
                "handlers": [name],
 | 
			
		||||
                "propagate": False}}})
 | 
			
		||||
                'level': level,
 | 
			
		||||
                'handlers': [name],
 | 
			
		||||
                'propagate': False}}})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TryExcept(contextlib.ContextDecorator):
 | 
			
		||||
@ -521,10 +521,10 @@ def set_sentry():
 | 
			
		||||
                return None  # do not send event
 | 
			
		||||
 | 
			
		||||
        event['tags'] = {
 | 
			
		||||
            "sys_argv": sys.argv[0],
 | 
			
		||||
            "sys_argv_name": Path(sys.argv[0]).name,
 | 
			
		||||
            "install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
 | 
			
		||||
            "os": ENVIRONMENT}
 | 
			
		||||
            'sys_argv': sys.argv[0],
 | 
			
		||||
            'sys_argv_name': Path(sys.argv[0]).name,
 | 
			
		||||
            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
 | 
			
		||||
            'os': ENVIRONMENT}
 | 
			
		||||
        return event
 | 
			
		||||
 | 
			
		||||
    if SETTINGS['sync'] and \
 | 
			
		||||
@ -533,24 +533,24 @@ def set_sentry():
 | 
			
		||||
            not is_pytest_running() and \
 | 
			
		||||
            not is_github_actions_ci() and \
 | 
			
		||||
            ((is_pip_package() and not is_git_dir()) or
 | 
			
		||||
             (get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
 | 
			
		||||
             (get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')):
 | 
			
		||||
 | 
			
		||||
        import hashlib
 | 
			
		||||
 | 
			
		||||
        import sentry_sdk  # noqa
 | 
			
		||||
 | 
			
		||||
        sentry_sdk.init(
 | 
			
		||||
            dsn="https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016",
 | 
			
		||||
            dsn='https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016',
 | 
			
		||||
            debug=False,
 | 
			
		||||
            traces_sample_rate=1.0,
 | 
			
		||||
            release=__version__,
 | 
			
		||||
            environment='production',  # 'dev' or 'production'
 | 
			
		||||
            before_send=before_send,
 | 
			
		||||
            ignore_errors=[KeyboardInterrupt, FileNotFoundError])
 | 
			
		||||
        sentry_sdk.set_user({"id": SETTINGS['uuid']})
 | 
			
		||||
        sentry_sdk.set_user({'id': SETTINGS['uuid']})
 | 
			
		||||
 | 
			
		||||
        # Disable all sentry logging
 | 
			
		||||
        for logger in "sentry_sdk", "sentry_sdk.errors":
 | 
			
		||||
        for logger in 'sentry_sdk', 'sentry_sdk.errors':
 | 
			
		||||
            logging.getLogger(logger).setLevel(logging.CRITICAL)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -620,7 +620,7 @@ if WINDOWS:
 | 
			
		||||
        setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x)))  # emoji safe logging
 | 
			
		||||
 | 
			
		||||
# Check first-install steps
 | 
			
		||||
PREFIX = colorstr("Ultralytics: ")
 | 
			
		||||
PREFIX = colorstr('Ultralytics: ')
 | 
			
		||||
SETTINGS = get_settings()
 | 
			
		||||
DATASETS_DIR = Path(SETTINGS['datasets_dir'])  # global datasets directory
 | 
			
		||||
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ except (ImportError, AssertionError):
 | 
			
		||||
    clearml = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _log_images(imgs_dict, group="", step=0):
 | 
			
		||||
def _log_images(imgs_dict, group='', step=0):
 | 
			
		||||
    task = Task.current_task()
 | 
			
		||||
    if task:
 | 
			
		||||
        for k, v in imgs_dict.items():
 | 
			
		||||
@ -20,7 +20,7 @@ def _log_images(imgs_dict, group="", step=0):
 | 
			
		||||
 | 
			
		||||
def on_pretrain_routine_start(trainer):
 | 
			
		||||
    # TODO: reuse existing task
 | 
			
		||||
    task = Task.init(project_name=trainer.args.project or "YOLOv8",
 | 
			
		||||
    task = Task.init(project_name=trainer.args.project or 'YOLOv8',
 | 
			
		||||
                     task_name=trainer.args.name,
 | 
			
		||||
                     tags=['YOLOv8'],
 | 
			
		||||
                     output_uri=True,
 | 
			
		||||
@ -31,15 +31,15 @@ def on_pretrain_routine_start(trainer):
 | 
			
		||||
 | 
			
		||||
def on_train_epoch_end(trainer):
 | 
			
		||||
    if trainer.epoch == 1:
 | 
			
		||||
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
 | 
			
		||||
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_fit_epoch_end(trainer):
 | 
			
		||||
    if trainer.epoch == 0:
 | 
			
		||||
        model_info = {
 | 
			
		||||
            "Parameters": get_num_params(trainer.model),
 | 
			
		||||
            "GFLOPs": round(get_flops(trainer.model), 3),
 | 
			
		||||
            "Inference speed (ms/img)": round(trainer.validator.speed[1], 3)}
 | 
			
		||||
            'Parameters': get_num_params(trainer.model),
 | 
			
		||||
            'GFLOPs': round(get_flops(trainer.model), 3),
 | 
			
		||||
            'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
 | 
			
		||||
        Task.current_task().connect(model_info, name='Model')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -50,7 +50,7 @@ def on_train_end(trainer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
callbacks = {
 | 
			
		||||
    "on_pretrain_routine_start": on_pretrain_routine_start,
 | 
			
		||||
    "on_train_epoch_end": on_train_epoch_end,
 | 
			
		||||
    "on_fit_epoch_end": on_fit_epoch_end,
 | 
			
		||||
    "on_train_end": on_train_end} if clearml else {}
 | 
			
		||||
    'on_pretrain_routine_start': on_pretrain_routine_start,
 | 
			
		||||
    'on_train_epoch_end': on_train_epoch_end,
 | 
			
		||||
    'on_fit_epoch_end': on_fit_epoch_end,
 | 
			
		||||
    'on_train_end': on_train_end} if clearml else {}
 | 
			
		||||
 | 
			
		||||
@ -10,13 +10,13 @@ except ImportError:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_pretrain_routine_start(trainer):
 | 
			
		||||
    experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8")
 | 
			
		||||
    experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
 | 
			
		||||
    experiment.log_parameters(vars(trainer.args))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_train_epoch_end(trainer):
 | 
			
		||||
    experiment = comet_ml.get_global_experiment()
 | 
			
		||||
    experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
 | 
			
		||||
    experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
 | 
			
		||||
    if trainer.epoch == 1:
 | 
			
		||||
        for f in trainer.save_dir.glob('train_batch*.jpg'):
 | 
			
		||||
            experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
 | 
			
		||||
@ -27,19 +27,19 @@ def on_fit_epoch_end(trainer):
 | 
			
		||||
    experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
 | 
			
		||||
    if trainer.epoch == 0:
 | 
			
		||||
        model_info = {
 | 
			
		||||
            "model/parameters": get_num_params(trainer.model),
 | 
			
		||||
            "model/GFLOPs": round(get_flops(trainer.model), 3),
 | 
			
		||||
            "model/speed(ms)": round(trainer.validator.speed[1], 3)}
 | 
			
		||||
            'model/parameters': get_num_params(trainer.model),
 | 
			
		||||
            'model/GFLOPs': round(get_flops(trainer.model), 3),
 | 
			
		||||
            'model/speed(ms)': round(trainer.validator.speed[1], 3)}
 | 
			
		||||
        experiment.log_metrics(model_info, step=trainer.epoch + 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_train_end(trainer):
 | 
			
		||||
    experiment = comet_ml.get_global_experiment()
 | 
			
		||||
    experiment.log_model("YOLOv8", file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
 | 
			
		||||
    experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
callbacks = {
 | 
			
		||||
    "on_pretrain_routine_start": on_pretrain_routine_start,
 | 
			
		||||
    "on_train_epoch_end": on_train_epoch_end,
 | 
			
		||||
    "on_fit_epoch_end": on_fit_epoch_end,
 | 
			
		||||
    "on_train_end": on_train_end} if comet_ml else {}
 | 
			
		||||
    'on_pretrain_routine_start': on_pretrain_routine_start,
 | 
			
		||||
    'on_train_epoch_end': on_train_epoch_end,
 | 
			
		||||
    'on_fit_epoch_end': on_fit_epoch_end,
 | 
			
		||||
    'on_train_end': on_train_end} if comet_ml else {}
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ def on_pretrain_routine_end(trainer):
 | 
			
		||||
    session = getattr(trainer, 'hub_session', None)
 | 
			
		||||
    if session:
 | 
			
		||||
        # Start timer for upload rate limit
 | 
			
		||||
        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
 | 
			
		||||
        session.t = {'metrics': time(), 'ckpt': time()}  # start timer on self.rate_limit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ def on_model_save(trainer):
 | 
			
		||||
        # Upload checkpoints with rate limiting
 | 
			
		||||
        is_best = trainer.best_fitness == trainer.fitness
 | 
			
		||||
        if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
 | 
			
		||||
            LOGGER.info(f"{PREFIX}Uploading checkpoint {session.model_id}")
 | 
			
		||||
            LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
 | 
			
		||||
            session.upload_model(trainer.epoch, trainer.last, is_best)
 | 
			
		||||
            session.t['ckpt'] = time()  # reset timer
 | 
			
		||||
 | 
			
		||||
@ -40,11 +40,11 @@ def on_train_end(trainer):
 | 
			
		||||
    session = getattr(trainer, 'hub_session', None)
 | 
			
		||||
    if session:
 | 
			
		||||
        # Upload final model and metrics with exponential standoff
 | 
			
		||||
        LOGGER.info(f"{PREFIX}Training completed successfully ✅\n"
 | 
			
		||||
                    f"{PREFIX}Uploading final {session.model_id}")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
 | 
			
		||||
                    f'{PREFIX}Uploading final {session.model_id}')
 | 
			
		||||
        session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
 | 
			
		||||
        session.shutdown()  # stop heartbeats
 | 
			
		||||
        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
 | 
			
		||||
        LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_train_start(trainer):
 | 
			
		||||
@ -64,11 +64,11 @@ def on_export_start(exporter):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
callbacks = {
 | 
			
		||||
    "on_pretrain_routine_end": on_pretrain_routine_end,
 | 
			
		||||
    "on_fit_epoch_end": on_fit_epoch_end,
 | 
			
		||||
    "on_model_save": on_model_save,
 | 
			
		||||
    "on_train_end": on_train_end,
 | 
			
		||||
    "on_train_start": on_train_start,
 | 
			
		||||
    "on_val_start": on_val_start,
 | 
			
		||||
    "on_predict_start": on_predict_start,
 | 
			
		||||
    "on_export_start": on_export_start}
 | 
			
		||||
    'on_pretrain_routine_end': on_pretrain_routine_end,
 | 
			
		||||
    'on_fit_epoch_end': on_fit_epoch_end,
 | 
			
		||||
    'on_model_save': on_model_save,
 | 
			
		||||
    'on_train_end': on_train_end,
 | 
			
		||||
    'on_train_start': on_train_start,
 | 
			
		||||
    'on_val_start': on_val_start,
 | 
			
		||||
    'on_predict_start': on_predict_start,
 | 
			
		||||
    'on_export_start': on_export_start}
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_batch_end(trainer):
 | 
			
		||||
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
 | 
			
		||||
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_fit_epoch_end(trainer):
 | 
			
		||||
@ -24,6 +24,6 @@ def on_fit_epoch_end(trainer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
callbacks = {
 | 
			
		||||
    "on_pretrain_routine_start": on_pretrain_routine_start,
 | 
			
		||||
    "on_fit_epoch_end": on_fit_epoch_end,
 | 
			
		||||
    "on_batch_end": on_batch_end}
 | 
			
		||||
    'on_pretrain_routine_start': on_pretrain_routine_start,
 | 
			
		||||
    'on_fit_epoch_end': on_fit_epoch_end,
 | 
			
		||||
    'on_batch_end': on_batch_end}
 | 
			
		||||
 | 
			
		||||
@ -71,7 +71,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
 | 
			
		||||
        msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
 | 
			
		||||
              "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
 | 
			
		||||
        if max_dim != 1:
 | 
			
		||||
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
 | 
			
		||||
            raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
 | 
			
		||||
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
 | 
			
		||||
        imgsz = [max(imgsz)]
 | 
			
		||||
    # Make image size a multiple of the stride
 | 
			
		||||
@ -87,9 +87,9 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
 | 
			
		||||
    return sz
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_version(current: str = "0.0.0",
 | 
			
		||||
                  minimum: str = "0.0.0",
 | 
			
		||||
                  name: str = "version ",
 | 
			
		||||
def check_version(current: str = '0.0.0',
 | 
			
		||||
                  minimum: str = '0.0.0',
 | 
			
		||||
                  name: str = 'version ',
 | 
			
		||||
                  pinned: bool = False,
 | 
			
		||||
                  hard: bool = False,
 | 
			
		||||
                  verbose: bool = False) -> bool:
 | 
			
		||||
@ -109,7 +109,7 @@ def check_version(current: str = "0.0.0",
 | 
			
		||||
    """
 | 
			
		||||
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
 | 
			
		||||
    result = (current == minimum) if pinned else (current >= minimum)  # bool
 | 
			
		||||
    warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
 | 
			
		||||
    warning_message = f'WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed'
 | 
			
		||||
    if hard:
 | 
			
		||||
        assert result, emojis(warning_message)  # assert min requirements met
 | 
			
		||||
    if verbose and not result:
 | 
			
		||||
@ -155,7 +155,7 @@ def check_online() -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    import socket
 | 
			
		||||
    with contextlib.suppress(Exception):
 | 
			
		||||
        host = socket.gethostbyname("www.github.com")
 | 
			
		||||
        host = socket.gethostbyname('www.github.com')
 | 
			
		||||
        socket.create_connection((host, 80), timeout=2)
 | 
			
		||||
        return True
 | 
			
		||||
    return False
 | 
			
		||||
@ -182,7 +182,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
 | 
			
		||||
    file = None
 | 
			
		||||
    if isinstance(requirements, Path):  # requirements.txt file
 | 
			
		||||
        file = requirements.resolve()
 | 
			
		||||
        assert file.exists(), f"{prefix} {file} not found, check failed."
 | 
			
		||||
        assert file.exists(), f'{prefix} {file} not found, check failed.'
 | 
			
		||||
        with file.open() as f:
 | 
			
		||||
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
 | 
			
		||||
    elif isinstance(requirements, str):
 | 
			
		||||
@ -200,7 +200,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
 | 
			
		||||
    if s and install and AUTOINSTALL:  # check environment variable
 | 
			
		||||
        LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
 | 
			
		||||
        try:
 | 
			
		||||
            assert check_online(), "AutoUpdate skipped (offline)"
 | 
			
		||||
            assert check_online(), 'AutoUpdate skipped (offline)'
 | 
			
		||||
            LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
 | 
			
		||||
            s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
 | 
			
		||||
                f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
 | 
			
		||||
@ -217,19 +217,19 @@ def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''):
 | 
			
		||||
        for f in file if isinstance(file, (list, tuple)) else [file]:
 | 
			
		||||
            s = Path(f).suffix.lower()  # file suffix
 | 
			
		||||
            if len(s):
 | 
			
		||||
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
 | 
			
		||||
                assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_yolov5u_filename(file: str):
 | 
			
		||||
    # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
 | 
			
		||||
    if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
 | 
			
		||||
        original_file = file
 | 
			
		||||
        file = re.sub(r"(.*yolov5([nsmlx]))\.", "\\1u.", file)  # i.e. yolov5n.pt -> yolov5nu.pt
 | 
			
		||||
        file = re.sub(r"(.*yolov3(|-tiny|-spp))\.", "\\1u.", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
 | 
			
		||||
        file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file)  # i.e. yolov5n.pt -> yolov5nu.pt
 | 
			
		||||
        file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
 | 
			
		||||
        if file != original_file:
 | 
			
		||||
            LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
 | 
			
		||||
                        f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
 | 
			
		||||
                        f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n")
 | 
			
		||||
                        f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
 | 
			
		||||
                        f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
 | 
			
		||||
    return file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -290,7 +290,7 @@ def check_yolo(verbose=True):
 | 
			
		||||
        # System info
 | 
			
		||||
        gib = 1 << 30  # bytes per GiB
 | 
			
		||||
        ram = psutil.virtual_memory().total
 | 
			
		||||
        total, used, free = shutil.disk_usage("/")
 | 
			
		||||
        total, used, free = shutil.disk_usage('/')
 | 
			
		||||
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
 | 
			
		||||
        with contextlib.suppress(Exception):  # clear display if ipython is installed
 | 
			
		||||
            from IPython import display
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ def find_free_network_port() -> int:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_ddp_file(trainer):
 | 
			
		||||
    import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
 | 
			
		||||
    import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
 | 
			
		||||
 | 
			
		||||
    if not trainer.resume:
 | 
			
		||||
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
 | 
			
		||||
@ -32,9 +32,9 @@ def generate_ddp_file(trainer):
 | 
			
		||||
    trainer = {trainer.__class__.__name__}(cfg=cfg)
 | 
			
		||||
    trainer.train()'''
 | 
			
		||||
    (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
 | 
			
		||||
    with tempfile.NamedTemporaryFile(prefix="_temp_",
 | 
			
		||||
                                     suffix=f"{id(trainer)}.py",
 | 
			
		||||
                                     mode="w+",
 | 
			
		||||
    with tempfile.NamedTemporaryFile(prefix='_temp_',
 | 
			
		||||
                                     suffix=f'{id(trainer)}.py',
 | 
			
		||||
                                     mode='w+',
 | 
			
		||||
                                     encoding='utf-8',
 | 
			
		||||
                                     dir=USER_CONFIG_DIR / 'DDP',
 | 
			
		||||
                                     delete=False) as file:
 | 
			
		||||
@ -47,18 +47,18 @@ def generate_ddp_command(world_size, trainer):
 | 
			
		||||
 | 
			
		||||
    # Get file and args (do not use sys.argv due to security vulnerability)
 | 
			
		||||
    exclude_args = ['save_dir']
 | 
			
		||||
    args = [f"{k}={v}" for k, v in vars(trainer.args).items() if k not in exclude_args]
 | 
			
		||||
    args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
 | 
			
		||||
    file = generate_ddp_file(trainer)  # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
 | 
			
		||||
 | 
			
		||||
    # Build command
 | 
			
		||||
    torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
 | 
			
		||||
    torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
 | 
			
		||||
    cmd = [
 | 
			
		||||
        sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
 | 
			
		||||
        f"{find_free_network_port()}", file] + args
 | 
			
		||||
        sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
 | 
			
		||||
        f'{find_free_network_port()}', file] + args
 | 
			
		||||
    return cmd, file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ddp_cleanup(trainer, file):
 | 
			
		||||
    # delete temp file if created
 | 
			
		||||
    if f"{id(trainer)}.py" in file:  # if temp_file suffix in file
 | 
			
		||||
    if f'{id(trainer)}.py' in file:  # if temp_file suffix in file
 | 
			
		||||
        os.remove(file)
 | 
			
		||||
 | 
			
		||||
@ -95,14 +95,14 @@ def safe_download(url,
 | 
			
		||||
                        torch.hub.download_url_to_file(url, f, progress=progress)
 | 
			
		||||
                    else:
 | 
			
		||||
                        from ultralytics.yolo.utils import TQDM_BAR_FORMAT
 | 
			
		||||
                        with request.urlopen(url) as response, tqdm(total=int(response.getheader("Content-Length", 0)),
 | 
			
		||||
                        with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
 | 
			
		||||
                                                                    desc=desc,
 | 
			
		||||
                                                                    disable=not progress,
 | 
			
		||||
                                                                    unit='B',
 | 
			
		||||
                                                                    unit_scale=True,
 | 
			
		||||
                                                                    unit_divisor=1024,
 | 
			
		||||
                                                                    bar_format=TQDM_BAR_FORMAT) as pbar:
 | 
			
		||||
                            with open(f, "wb") as f_opened:
 | 
			
		||||
                            with open(f, 'wb') as f_opened:
 | 
			
		||||
                                for data in response:
 | 
			
		||||
                                    f_opened.write(data)
 | 
			
		||||
                                    pbar.update(len(data))
 | 
			
		||||
@ -171,7 +171,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
 | 
			
		||||
                tag, assets = github_assets(repo)  # latest release
 | 
			
		||||
            except Exception:
 | 
			
		||||
                try:
 | 
			
		||||
                    tag = subprocess.check_output(["git", "tag"]).decode().split()[-1]
 | 
			
		||||
                    tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    tag = release
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,15 +24,15 @@ to_4tuple = _ntuple(4)
 | 
			
		||||
# `xyxy` means left top and right bottom
 | 
			
		||||
# `xywh` means center x, center y and width, height(yolo format)
 | 
			
		||||
# `ltwh` means left top and width, height(coco format)
 | 
			
		||||
_formats = ["xyxy", "xywh", "ltwh"]
 | 
			
		||||
_formats = ['xyxy', 'xywh', 'ltwh']
 | 
			
		||||
 | 
			
		||||
__all__ = ["Bboxes"]
 | 
			
		||||
__all__ = ['Bboxes']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Bboxes:
 | 
			
		||||
    """Now only numpy is supported"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, bboxes, format="xyxy") -> None:
 | 
			
		||||
    def __init__(self, bboxes, format='xyxy') -> None:
 | 
			
		||||
        assert format in _formats
 | 
			
		||||
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
 | 
			
		||||
        assert bboxes.ndim == 2
 | 
			
		||||
@ -67,17 +67,17 @@ class Bboxes:
 | 
			
		||||
        assert format in _formats
 | 
			
		||||
        if self.format == format:
 | 
			
		||||
            return
 | 
			
		||||
        elif self.format == "xyxy":
 | 
			
		||||
            bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else xyxy2ltwh(self.bboxes)
 | 
			
		||||
        elif self.format == "xywh":
 | 
			
		||||
            bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else xywh2ltwh(self.bboxes)
 | 
			
		||||
        elif self.format == 'xyxy':
 | 
			
		||||
            bboxes = xyxy2xywh(self.bboxes) if format == 'xywh' else xyxy2ltwh(self.bboxes)
 | 
			
		||||
        elif self.format == 'xywh':
 | 
			
		||||
            bboxes = xywh2xyxy(self.bboxes) if format == 'xyxy' else xywh2ltwh(self.bboxes)
 | 
			
		||||
        else:
 | 
			
		||||
            bboxes = ltwh2xyxy(self.bboxes) if format == "xyxy" else ltwh2xywh(self.bboxes)
 | 
			
		||||
            bboxes = ltwh2xyxy(self.bboxes) if format == 'xyxy' else ltwh2xywh(self.bboxes)
 | 
			
		||||
        self.bboxes = bboxes
 | 
			
		||||
        self.format = format
 | 
			
		||||
 | 
			
		||||
    def areas(self):
 | 
			
		||||
        self.convert("xyxy")
 | 
			
		||||
        self.convert('xyxy')
 | 
			
		||||
        return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
 | 
			
		||||
 | 
			
		||||
    # def denormalize(self, w, h):
 | 
			
		||||
@ -128,7 +128,7 @@ class Bboxes:
 | 
			
		||||
        return len(self.bboxes)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
 | 
			
		||||
    def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
 | 
			
		||||
        """
 | 
			
		||||
        Concatenates a list of Boxes into a single Bboxes
 | 
			
		||||
 | 
			
		||||
@ -147,7 +147,7 @@ class Bboxes:
 | 
			
		||||
            return boxes_list[0]
 | 
			
		||||
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index) -> "Bboxes":
 | 
			
		||||
    def __getitem__(self, index) -> 'Bboxes':
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: int, slice, or a BoolArray
 | 
			
		||||
@ -158,13 +158,13 @@ class Bboxes:
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            return Bboxes(self.bboxes[index].view(1, -1))
 | 
			
		||||
        b = self.bboxes[index]
 | 
			
		||||
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
 | 
			
		||||
        assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
 | 
			
		||||
        return Bboxes(b)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Instances:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
 | 
			
		||||
    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            bboxes (ndarray): bboxes with shape [N, 4].
 | 
			
		||||
@ -227,7 +227,7 @@ class Instances:
 | 
			
		||||
 | 
			
		||||
    def add_padding(self, padw, padh):
 | 
			
		||||
        # handle rect and mosaic situation
 | 
			
		||||
        assert not self.normalized, "you should add padding with absolute coordinates."
 | 
			
		||||
        assert not self.normalized, 'you should add padding with absolute coordinates.'
 | 
			
		||||
        self._bboxes.add(offset=(padw, padh, padw, padh))
 | 
			
		||||
        self.segments[..., 0] += padw
 | 
			
		||||
        self.segments[..., 1] += padh
 | 
			
		||||
@ -235,7 +235,7 @@ class Instances:
 | 
			
		||||
            self.keypoints[..., 0] += padw
 | 
			
		||||
            self.keypoints[..., 1] += padh
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index) -> "Instances":
 | 
			
		||||
    def __getitem__(self, index) -> 'Instances':
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: int, slice, or a BoolArray
 | 
			
		||||
@ -256,7 +256,7 @@ class Instances:
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def flipud(self, h):
 | 
			
		||||
        if self._bboxes.format == "xyxy":
 | 
			
		||||
        if self._bboxes.format == 'xyxy':
 | 
			
		||||
            y1 = self.bboxes[:, 1].copy()
 | 
			
		||||
            y2 = self.bboxes[:, 3].copy()
 | 
			
		||||
            self.bboxes[:, 1] = h - y2
 | 
			
		||||
@ -268,7 +268,7 @@ class Instances:
 | 
			
		||||
            self.keypoints[..., 1] = h - self.keypoints[..., 1]
 | 
			
		||||
 | 
			
		||||
    def fliplr(self, w):
 | 
			
		||||
        if self._bboxes.format == "xyxy":
 | 
			
		||||
        if self._bboxes.format == 'xyxy':
 | 
			
		||||
            x1 = self.bboxes[:, 0].copy()
 | 
			
		||||
            x2 = self.bboxes[:, 2].copy()
 | 
			
		||||
            self.bboxes[:, 0] = w - x2
 | 
			
		||||
@ -281,10 +281,10 @@ class Instances:
 | 
			
		||||
 | 
			
		||||
    def clip(self, w, h):
 | 
			
		||||
        ori_format = self._bboxes.format
 | 
			
		||||
        self.convert_bbox(format="xyxy")
 | 
			
		||||
        self.convert_bbox(format='xyxy')
 | 
			
		||||
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
 | 
			
		||||
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
 | 
			
		||||
        if ori_format != "xyxy":
 | 
			
		||||
        if ori_format != 'xyxy':
 | 
			
		||||
            self.convert_bbox(format=ori_format)
 | 
			
		||||
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
 | 
			
		||||
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
 | 
			
		||||
@ -304,7 +304,7 @@ class Instances:
 | 
			
		||||
        return len(self.bboxes)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
 | 
			
		||||
    def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
 | 
			
		||||
        """
 | 
			
		||||
        Concatenates a list of Boxes into a single Bboxes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,7 @@ class VarifocalLoss(nn.Module):
 | 
			
		||||
    def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
 | 
			
		||||
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
 | 
			
		||||
        with torch.cuda.amp.autocast(enabled=False):
 | 
			
		||||
            loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") *
 | 
			
		||||
            loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
 | 
			
		||||
                    weight).sum()
 | 
			
		||||
        return loss
 | 
			
		||||
 | 
			
		||||
@ -52,5 +52,5 @@ class BboxLoss(nn.Module):
 | 
			
		||||
        tr = tl + 1  # target right
 | 
			
		||||
        wl = tr - target  # weight left
 | 
			
		||||
        wr = 1 - wl  # weight right
 | 
			
		||||
        return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl +
 | 
			
		||||
                F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True)
 | 
			
		||||
        return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
 | 
			
		||||
                F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
 | 
			
		||||
 | 
			
		||||
@ -238,14 +238,14 @@ class ConfusionMatrix:
 | 
			
		||||
        nc, nn = self.nc, len(names)  # number of classes, names
 | 
			
		||||
        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
 | 
			
		||||
        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
 | 
			
		||||
        ticklabels = (names + ['background']) if labels else "auto"
 | 
			
		||||
        ticklabels = (names + ['background']) if labels else 'auto'
 | 
			
		||||
        with warnings.catch_warnings():
 | 
			
		||||
            warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
 | 
			
		||||
            sn.heatmap(array,
 | 
			
		||||
                       ax=ax,
 | 
			
		||||
                       annot=nc < 30,
 | 
			
		||||
                       annot_kws={
 | 
			
		||||
                           "size": 8},
 | 
			
		||||
                           'size': 8},
 | 
			
		||||
                       cmap='Blues',
 | 
			
		||||
                       fmt='.2f',
 | 
			
		||||
                       square=True,
 | 
			
		||||
@ -287,7 +287,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
 | 
			
		||||
    ax.set_ylabel('Precision')
 | 
			
		||||
    ax.set_xlim(0, 1)
 | 
			
		||||
    ax.set_ylim(0, 1)
 | 
			
		||||
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
 | 
			
		||||
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
 | 
			
		||||
    ax.set_title('Precision-Recall Curve')
 | 
			
		||||
    fig.savefig(save_dir, dpi=250)
 | 
			
		||||
    plt.close(fig)
 | 
			
		||||
@ -309,7 +309,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
 | 
			
		||||
    ax.set_ylabel(ylabel)
 | 
			
		||||
    ax.set_xlim(0, 1)
 | 
			
		||||
    ax.set_ylim(0, 1)
 | 
			
		||||
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
 | 
			
		||||
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
 | 
			
		||||
    ax.set_title(f'{ylabel}-Confidence Curve')
 | 
			
		||||
    fig.savefig(save_dir, dpi=250)
 | 
			
		||||
    plt.close(fig)
 | 
			
		||||
@ -343,7 +343,7 @@ def compute_ap(recall, precision):
 | 
			
		||||
    return ap, mpre, mrec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
 | 
			
		||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
 | 
			
		||||
    """ Compute the average precision, given the recall and precision curves.
 | 
			
		||||
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
 | 
			
		||||
    # Arguments
 | 
			
		||||
@ -507,7 +507,7 @@ class Metric:
 | 
			
		||||
 | 
			
		||||
class DetMetrics:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
 | 
			
		||||
    def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
 | 
			
		||||
        self.save_dir = save_dir
 | 
			
		||||
        self.plot = plot
 | 
			
		||||
        self.names = names
 | 
			
		||||
@ -521,7 +521,7 @@ class DetMetrics:
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def keys(self):
 | 
			
		||||
        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
 | 
			
		||||
        return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
 | 
			
		||||
 | 
			
		||||
    def mean_results(self):
 | 
			
		||||
        return self.box.mean_results()
 | 
			
		||||
@ -543,12 +543,12 @@ class DetMetrics:
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def results_dict(self):
 | 
			
		||||
        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
 | 
			
		||||
        return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SegmentMetrics:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
 | 
			
		||||
    def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
 | 
			
		||||
        self.save_dir = save_dir
 | 
			
		||||
        self.plot = plot
 | 
			
		||||
        self.names = names
 | 
			
		||||
@ -563,7 +563,7 @@ class SegmentMetrics:
 | 
			
		||||
                                    plot=self.plot,
 | 
			
		||||
                                    save_dir=self.save_dir,
 | 
			
		||||
                                    names=self.names,
 | 
			
		||||
                                    prefix="Mask")[2:]
 | 
			
		||||
                                    prefix='Mask')[2:]
 | 
			
		||||
        self.seg.nc = len(self.names)
 | 
			
		||||
        self.seg.update(results_mask)
 | 
			
		||||
        results_box = ap_per_class(tp_b,
 | 
			
		||||
@ -573,15 +573,15 @@ class SegmentMetrics:
 | 
			
		||||
                                   plot=self.plot,
 | 
			
		||||
                                   save_dir=self.save_dir,
 | 
			
		||||
                                   names=self.names,
 | 
			
		||||
                                   prefix="Box")[2:]
 | 
			
		||||
                                   prefix='Box')[2:]
 | 
			
		||||
        self.box.nc = len(self.names)
 | 
			
		||||
        self.box.update(results_box)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def keys(self):
 | 
			
		||||
        return [
 | 
			
		||||
            "metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
 | 
			
		||||
            "metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
 | 
			
		||||
            'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
 | 
			
		||||
            'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
 | 
			
		||||
 | 
			
		||||
    def mean_results(self):
 | 
			
		||||
        return self.box.mean_results() + self.seg.mean_results()
 | 
			
		||||
@ -604,7 +604,7 @@ class SegmentMetrics:
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def results_dict(self):
 | 
			
		||||
        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
 | 
			
		||||
        return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClassifyMetrics:
 | 
			
		||||
@ -626,8 +626,8 @@ class ClassifyMetrics:
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def results_dict(self):
 | 
			
		||||
        return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
 | 
			
		||||
        return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def keys(self):
 | 
			
		||||
        return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
 | 
			
		||||
        return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
 | 
			
		||||
 | 
			
		||||
@ -715,4 +715,4 @@ def clean_str(s):
 | 
			
		||||
    Returns:
 | 
			
		||||
      (str): a string with special characters replaced by an underscore _
 | 
			
		||||
    """
 | 
			
		||||
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
 | 
			
		||||
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
 | 
			
		||||
 | 
			
		||||
@ -61,7 +61,7 @@ def DDP_model(model):
 | 
			
		||||
 | 
			
		||||
def select_device(device='', batch=0, newline=False):
 | 
			
		||||
    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
 | 
			
		||||
    s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "
 | 
			
		||||
    s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
 | 
			
		||||
    device = str(device).lower()
 | 
			
		||||
    for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
 | 
			
		||||
        device = device.replace(remove, '')  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
 | 
			
		||||
@ -74,15 +74,15 @@ def select_device(device='', batch=0, newline=False):
 | 
			
		||||
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
 | 
			
		||||
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
 | 
			
		||||
            LOGGER.info(s)
 | 
			
		||||
            install = "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " \
 | 
			
		||||
                      "CUDA devices are seen by torch.\n" if torch.cuda.device_count() == 0 else ""
 | 
			
		||||
            install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
 | 
			
		||||
                      'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
 | 
			
		||||
            raise ValueError(f"Invalid CUDA 'device={device}' requested."
 | 
			
		||||
                             f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
 | 
			
		||||
                             f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
 | 
			
		||||
                             f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
 | 
			
		||||
                             f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
 | 
			
		||||
                             f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
 | 
			
		||||
                             f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
 | 
			
		||||
                             f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
 | 
			
		||||
                             f"{install}")
 | 
			
		||||
                             f'{install}')
 | 
			
		||||
 | 
			
		||||
    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
 | 
			
		||||
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
 | 
			
		||||
@ -177,7 +177,7 @@ def model_info(model, verbose=False, imgsz=640):
 | 
			
		||||
    fused = ' (fused)' if model.is_fused() else ''
 | 
			
		||||
    fs = f', {flops:.1f} GFLOPs' if flops else ''
 | 
			
		||||
    m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
 | 
			
		||||
    LOGGER.info(f"{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
 | 
			
		||||
    LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_num_params(model):
 | 
			
		||||
 | 
			
		||||
@ -2,4 +2,4 @@
 | 
			
		||||
 | 
			
		||||
from ultralytics.yolo.v8 import classify, detect, segment
 | 
			
		||||
 | 
			
		||||
__all__ = ["classify", "segment", "detect"]
 | 
			
		||||
__all__ = ['classify', 'segment', 'detect']
 | 
			
		||||
 | 
			
		||||
@ -4,4 +4,4 @@ from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predic
 | 
			
		||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
 | 
			
		||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
 | 
			
		||||
 | 
			
		||||
__all__ = ["ClassificationPredictor", "predict", "ClassificationTrainer", "train", "ClassificationValidator", "val"]
 | 
			
		||||
__all__ = ['ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val']
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ class ClassificationPredictor(BasePredictor):
 | 
			
		||||
 | 
			
		||||
    def write_results(self, idx, results, batch):
 | 
			
		||||
        p, im, im0 = batch
 | 
			
		||||
        log_string = ""
 | 
			
		||||
        log_string = ''
 | 
			
		||||
        if len(im.shape) == 3:
 | 
			
		||||
            im = im[None]  # expand for batch dim
 | 
			
		||||
        self.seen += 1
 | 
			
		||||
@ -65,9 +65,9 @@ class ClassificationPredictor(BasePredictor):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
 | 
			
		||||
    source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
 | 
			
		||||
        else "https://ultralytics.com/images/bus.jpg"
 | 
			
		||||
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
 | 
			
		||||
    source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
 | 
			
		||||
        else 'https://ultralytics.com/images/bus.jpg'
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, source=source)
 | 
			
		||||
    if use_python:
 | 
			
		||||
@ -78,5 +78,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        predictor.predict_cli()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    predict()
 | 
			
		||||
 | 
			
		||||
@ -16,14 +16,14 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
 | 
			
		||||
        if overrides is None:
 | 
			
		||||
            overrides = {}
 | 
			
		||||
        overrides["task"] = "classify"
 | 
			
		||||
        overrides['task'] = 'classify'
 | 
			
		||||
        super().__init__(cfg, overrides)
 | 
			
		||||
 | 
			
		||||
    def set_model_attributes(self):
 | 
			
		||||
        self.model.names = self.data["names"]
 | 
			
		||||
        self.model.names = self.data['names']
 | 
			
		||||
 | 
			
		||||
    def get_model(self, cfg=None, weights=None, verbose=True):
 | 
			
		||||
        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
 | 
			
		||||
        model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
 | 
			
		||||
        if weights:
 | 
			
		||||
            model.load(weights)
 | 
			
		||||
 | 
			
		||||
@ -53,11 +53,11 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
 | 
			
		||||
        model = str(self.model)
 | 
			
		||||
        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
 | 
			
		||||
        if model.endswith(".pt"):
 | 
			
		||||
        if model.endswith('.pt'):
 | 
			
		||||
            self.model, _ = attempt_load_one_weight(model, device='cpu')
 | 
			
		||||
            for p in self.model.parameters():
 | 
			
		||||
                p.requires_grad = True  # for training
 | 
			
		||||
        elif model.endswith(".yaml"):
 | 
			
		||||
        elif model.endswith('.yaml'):
 | 
			
		||||
            self.model = self.get_model(cfg=model)
 | 
			
		||||
        elif model in torchvision.models.__dict__:
 | 
			
		||||
            pretrained = True
 | 
			
		||||
@ -67,15 +67,15 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
 | 
			
		||||
        return  # dont return ckpt. Classification doesn't support resume
 | 
			
		||||
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
 | 
			
		||||
        loader = build_classification_dataloader(path=dataset_path,
 | 
			
		||||
                                                 imgsz=self.args.imgsz,
 | 
			
		||||
                                                 batch_size=batch_size if mode == "train" else (batch_size * 2),
 | 
			
		||||
                                                 augment=mode == "train",
 | 
			
		||||
                                                 batch_size=batch_size if mode == 'train' else (batch_size * 2),
 | 
			
		||||
                                                 augment=mode == 'train',
 | 
			
		||||
                                                 rank=rank,
 | 
			
		||||
                                                 workers=self.args.workers)
 | 
			
		||||
        # Attach inference transforms
 | 
			
		||||
        if mode != "train":
 | 
			
		||||
        if mode != 'train':
 | 
			
		||||
            if is_parallel(self.model):
 | 
			
		||||
                self.model.module.transforms = loader.dataset.torch_transforms
 | 
			
		||||
            else:
 | 
			
		||||
@ -83,8 +83,8 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
        return loader
 | 
			
		||||
 | 
			
		||||
    def preprocess_batch(self, batch):
 | 
			
		||||
        batch["img"] = batch["img"].to(self.device)
 | 
			
		||||
        batch["cls"] = batch["cls"].to(self.device)
 | 
			
		||||
        batch['img'] = batch['img'].to(self.device)
 | 
			
		||||
        batch['cls'] = batch['cls'].to(self.device)
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
    def progress_string(self):
 | 
			
		||||
@ -96,7 +96,7 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
        return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
 | 
			
		||||
 | 
			
		||||
    def criterion(self, preds, batch):
 | 
			
		||||
        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction='sum') / self.args.nbs
 | 
			
		||||
        loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
 | 
			
		||||
        loss_items = loss.detach()
 | 
			
		||||
        return loss, loss_items
 | 
			
		||||
 | 
			
		||||
@ -112,12 +112,12 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
    #     else:
 | 
			
		||||
    #         return keys
 | 
			
		||||
 | 
			
		||||
    def label_loss_items(self, loss_items=None, prefix="train"):
 | 
			
		||||
    def label_loss_items(self, loss_items=None, prefix='train'):
 | 
			
		||||
        """
 | 
			
		||||
        Returns a loss dict with labelled training loss items tensor
 | 
			
		||||
        """
 | 
			
		||||
        # Not needed for classification but necessary for segmentation & detection
 | 
			
		||||
        keys = [f"{prefix}/{x}" for x in self.loss_names]
 | 
			
		||||
        keys = [f'{prefix}/{x}' for x in self.loss_names]
 | 
			
		||||
        if loss_items is None:
 | 
			
		||||
            return keys
 | 
			
		||||
        loss_items = [round(float(loss_items), 5)]
 | 
			
		||||
@ -140,8 +140,8 @@ class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
 | 
			
		||||
    data = cfg.data or "mnist160"  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
 | 
			
		||||
    data = cfg.data or 'mnist160'  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    device = cfg.device if cfg.device is not None else ''
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, data=data, device=device)
 | 
			
		||||
@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        trainer.train()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    train()
 | 
			
		||||
 | 
			
		||||
@ -21,14 +21,14 @@ class ClassificationValidator(BaseValidator):
 | 
			
		||||
        self.targets = []
 | 
			
		||||
 | 
			
		||||
    def preprocess(self, batch):
 | 
			
		||||
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
 | 
			
		||||
        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
 | 
			
		||||
        batch["cls"] = batch["cls"].to(self.device)
 | 
			
		||||
        batch['img'] = batch['img'].to(self.device, non_blocking=True)
 | 
			
		||||
        batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
 | 
			
		||||
        batch['cls'] = batch['cls'].to(self.device)
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
    def update_metrics(self, preds, batch):
 | 
			
		||||
        self.pred.append(preds.argsort(1, descending=True)[:, :5])
 | 
			
		||||
        self.targets.append(batch["cls"])
 | 
			
		||||
        self.targets.append(batch['cls'])
 | 
			
		||||
 | 
			
		||||
    def get_stats(self):
 | 
			
		||||
        self.metrics.process(self.targets, self.pred)
 | 
			
		||||
@ -42,12 +42,12 @@ class ClassificationValidator(BaseValidator):
 | 
			
		||||
 | 
			
		||||
    def print_results(self):
 | 
			
		||||
        pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format
 | 
			
		||||
        self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 | 
			
		||||
        self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def val(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
 | 
			
		||||
    data = cfg.data or "mnist160"
 | 
			
		||||
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
 | 
			
		||||
    data = cfg.data or 'mnist160'
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, data=data)
 | 
			
		||||
    if use_python:
 | 
			
		||||
@ -58,5 +58,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        validator(model=args['model'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    val()
 | 
			
		||||
 | 
			
		||||
@ -4,4 +4,4 @@ from .predict import DetectionPredictor, predict
 | 
			
		||||
from .train import DetectionTrainer, train
 | 
			
		||||
from .val import DetectionValidator, val
 | 
			
		||||
 | 
			
		||||
__all__ = ["DetectionPredictor", "predict", "DetectionTrainer", "train", "DetectionValidator", "val"]
 | 
			
		||||
__all__ = ['DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val']
 | 
			
		||||
 | 
			
		||||
@ -37,7 +37,7 @@ class DetectionPredictor(BasePredictor):
 | 
			
		||||
 | 
			
		||||
    def write_results(self, idx, results, batch):
 | 
			
		||||
        p, im, im0 = batch
 | 
			
		||||
        log_string = ""
 | 
			
		||||
        log_string = ''
 | 
			
		||||
        if len(im.shape) == 3:
 | 
			
		||||
            im = im[None]  # expand for batch dim
 | 
			
		||||
        self.seen += 1
 | 
			
		||||
@ -69,7 +69,7 @@ class DetectionPredictor(BasePredictor):
 | 
			
		||||
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
 | 
			
		||||
            if self.args.save or self.args.save_crop or self.args.show:  # Add bbox to image
 | 
			
		||||
                c = int(cls)  # integer class
 | 
			
		||||
                name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
 | 
			
		||||
                name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
 | 
			
		||||
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
 | 
			
		||||
                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
 | 
			
		||||
            if self.args.save_crop:
 | 
			
		||||
@ -82,9 +82,9 @@ class DetectionPredictor(BasePredictor):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n.pt"
 | 
			
		||||
    source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
 | 
			
		||||
        else "https://ultralytics.com/images/bus.jpg"
 | 
			
		||||
    model = cfg.model or 'yolov8n.pt'
 | 
			
		||||
    source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
 | 
			
		||||
        else 'https://ultralytics.com/images/bus.jpg'
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, source=source)
 | 
			
		||||
    if use_python:
 | 
			
		||||
@ -95,5 +95,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        predictor.predict_cli()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    predict()
 | 
			
		||||
 | 
			
		||||
@ -20,7 +20,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
 | 
			
		||||
# BaseTrainer python usage
 | 
			
		||||
class DetectionTrainer(BaseTrainer):
 | 
			
		||||
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size, mode='train', rank=0):
 | 
			
		||||
        # TODO: manage splits differently
 | 
			
		||||
        # calculate stride - check if model is initialized
 | 
			
		||||
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
 | 
			
		||||
@ -29,21 +29,21 @@ class DetectionTrainer(BaseTrainer):
 | 
			
		||||
                                 batch_size=batch_size,
 | 
			
		||||
                                 stride=gs,
 | 
			
		||||
                                 hyp=vars(self.args),
 | 
			
		||||
                                 augment=mode == "train",
 | 
			
		||||
                                 augment=mode == 'train',
 | 
			
		||||
                                 cache=self.args.cache,
 | 
			
		||||
                                 pad=0 if mode == "train" else 0.5,
 | 
			
		||||
                                 rect=self.args.rect or mode == "val",
 | 
			
		||||
                                 pad=0 if mode == 'train' else 0.5,
 | 
			
		||||
                                 rect=self.args.rect or mode == 'val',
 | 
			
		||||
                                 rank=rank,
 | 
			
		||||
                                 workers=self.args.workers,
 | 
			
		||||
                                 close_mosaic=self.args.close_mosaic != 0,
 | 
			
		||||
                                 prefix=colorstr(f'{mode}: '),
 | 
			
		||||
                                 shuffle=mode == "train",
 | 
			
		||||
                                 shuffle=mode == 'train',
 | 
			
		||||
                                 seed=self.args.seed)[0] if self.args.v5loader else \
 | 
			
		||||
            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode,
 | 
			
		||||
                             rect=mode == "val", names=self.data['names'])[0]
 | 
			
		||||
                             rect=mode == 'val', names=self.data['names'])[0]
 | 
			
		||||
 | 
			
		||||
    def preprocess_batch(self, batch):
 | 
			
		||||
        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
 | 
			
		||||
        batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
    def set_model_attributes(self):
 | 
			
		||||
@ -51,13 +51,13 @@ class DetectionTrainer(BaseTrainer):
 | 
			
		||||
        # self.args.box *= 3 / nl  # scale to layers
 | 
			
		||||
        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
 | 
			
		||||
        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
 | 
			
		||||
        self.model.nc = self.data["nc"]  # attach number of classes to model
 | 
			
		||||
        self.model.names = self.data["names"]  # attach class names to model
 | 
			
		||||
        self.model.nc = self.data['nc']  # attach number of classes to model
 | 
			
		||||
        self.model.names = self.data['names']  # attach class names to model
 | 
			
		||||
        self.model.args = self.args  # attach hyperparameters to model
 | 
			
		||||
        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 | 
			
		||||
 | 
			
		||||
    def get_model(self, cfg=None, weights=None, verbose=True):
 | 
			
		||||
        model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
 | 
			
		||||
        model = DetectionModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
 | 
			
		||||
        if weights:
 | 
			
		||||
            model.load(weights)
 | 
			
		||||
 | 
			
		||||
@ -75,12 +75,12 @@ class DetectionTrainer(BaseTrainer):
 | 
			
		||||
            self.compute_loss = Loss(de_parallel(self.model))
 | 
			
		||||
        return self.compute_loss(preds, batch)
 | 
			
		||||
 | 
			
		||||
    def label_loss_items(self, loss_items=None, prefix="train"):
 | 
			
		||||
    def label_loss_items(self, loss_items=None, prefix='train'):
 | 
			
		||||
        """
 | 
			
		||||
        Returns a loss dict with labelled training loss items tensor
 | 
			
		||||
        """
 | 
			
		||||
        # Not needed for classification but necessary for segmentation & detection
 | 
			
		||||
        keys = [f"{prefix}/{x}" for x in self.loss_names]
 | 
			
		||||
        keys = [f'{prefix}/{x}' for x in self.loss_names]
 | 
			
		||||
        if loss_items is not None:
 | 
			
		||||
            loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
 | 
			
		||||
            return dict(zip(keys, loss_items))
 | 
			
		||||
@ -92,12 +92,12 @@ class DetectionTrainer(BaseTrainer):
 | 
			
		||||
                (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
 | 
			
		||||
 | 
			
		||||
    def plot_training_samples(self, batch, ni):
 | 
			
		||||
        plot_images(images=batch["img"],
 | 
			
		||||
                    batch_idx=batch["batch_idx"],
 | 
			
		||||
                    cls=batch["cls"].squeeze(-1),
 | 
			
		||||
                    bboxes=batch["bboxes"],
 | 
			
		||||
                    paths=batch["im_file"],
 | 
			
		||||
                    fname=self.save_dir / f"train_batch{ni}.jpg")
 | 
			
		||||
        plot_images(images=batch['img'],
 | 
			
		||||
                    batch_idx=batch['batch_idx'],
 | 
			
		||||
                    cls=batch['cls'].squeeze(-1),
 | 
			
		||||
                    bboxes=batch['bboxes'],
 | 
			
		||||
                    paths=batch['im_file'],
 | 
			
		||||
                    fname=self.save_dir / f'train_batch{ni}.jpg')
 | 
			
		||||
 | 
			
		||||
    def plot_metrics(self):
 | 
			
		||||
        plot_results(file=self.csv)  # save results.png
 | 
			
		||||
@ -169,7 +169,7 @@ class Loss:
 | 
			
		||||
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
 | 
			
		||||
 | 
			
		||||
        # targets
 | 
			
		||||
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
 | 
			
		||||
        targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
 | 
			
		||||
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
 | 
			
		||||
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
 | 
			
		||||
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
 | 
			
		||||
@ -201,8 +201,8 @@ class Loss:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n.pt"
 | 
			
		||||
    data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    model = cfg.model or 'yolov8n.pt'
 | 
			
		||||
    data = cfg.data or 'coco128.yaml'  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    device = cfg.device if cfg.device is not None else ''
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, data=data, device=device)
 | 
			
		||||
@ -214,5 +214,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        trainer.train()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    train()
 | 
			
		||||
 | 
			
		||||
@ -28,13 +28,13 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
        self.niou = self.iouv.numel()
 | 
			
		||||
 | 
			
		||||
    def preprocess(self, batch):
 | 
			
		||||
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
 | 
			
		||||
        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
 | 
			
		||||
        for k in ["batch_idx", "cls", "bboxes"]:
 | 
			
		||||
        batch['img'] = batch['img'].to(self.device, non_blocking=True)
 | 
			
		||||
        batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
 | 
			
		||||
        for k in ['batch_idx', 'cls', 'bboxes']:
 | 
			
		||||
            batch[k] = batch[k].to(self.device)
 | 
			
		||||
 | 
			
		||||
        nb = len(batch["img"])
 | 
			
		||||
        self.lb = [torch.cat([batch["cls"], batch["bboxes"]], dim=-1)[batch["batch_idx"] == i]
 | 
			
		||||
        nb = len(batch['img'])
 | 
			
		||||
        self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i]
 | 
			
		||||
                   for i in range(nb)] if self.args.save_hybrid else []  # for autolabelling
 | 
			
		||||
 | 
			
		||||
        return batch
 | 
			
		||||
@ -54,7 +54,7 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
        self.stats = []
 | 
			
		||||
 | 
			
		||||
    def get_desc(self):
 | 
			
		||||
        return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
 | 
			
		||||
        return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
 | 
			
		||||
 | 
			
		||||
    def postprocess(self, preds):
 | 
			
		||||
        preds = ops.non_max_suppression(preds,
 | 
			
		||||
@ -69,11 +69,11 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
    def update_metrics(self, preds, batch):
 | 
			
		||||
        # Metrics
 | 
			
		||||
        for si, pred in enumerate(preds):
 | 
			
		||||
            idx = batch["batch_idx"] == si
 | 
			
		||||
            cls = batch["cls"][idx]
 | 
			
		||||
            bbox = batch["bboxes"][idx]
 | 
			
		||||
            idx = batch['batch_idx'] == si
 | 
			
		||||
            cls = batch['cls'][idx]
 | 
			
		||||
            bbox = batch['bboxes'][idx]
 | 
			
		||||
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
 | 
			
		||||
            shape = batch["ori_shape"][si]
 | 
			
		||||
            shape = batch['ori_shape'][si]
 | 
			
		||||
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
 | 
			
		||||
            self.seen += 1
 | 
			
		||||
 | 
			
		||||
@ -88,16 +88,16 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
            if self.args.single_cls:
 | 
			
		||||
                pred[:, 5] = 0
 | 
			
		||||
            predn = pred.clone()
 | 
			
		||||
            ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
 | 
			
		||||
                            ratio_pad=batch["ratio_pad"][si])  # native-space pred
 | 
			
		||||
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
 | 
			
		||||
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred
 | 
			
		||||
 | 
			
		||||
            # Evaluate
 | 
			
		||||
            if nl:
 | 
			
		||||
                height, width = batch["img"].shape[2:]
 | 
			
		||||
                height, width = batch['img'].shape[2:]
 | 
			
		||||
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
 | 
			
		||||
                    (width, height, width, height), device=self.device)  # target boxes
 | 
			
		||||
                ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape,
 | 
			
		||||
                                ratio_pad=batch["ratio_pad"][si])  # native-space labels
 | 
			
		||||
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
 | 
			
		||||
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
 | 
			
		||||
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
 | 
			
		||||
                correct_bboxes = self._process_batch(predn, labelsn)
 | 
			
		||||
                # TODO: maybe remove these `self.` arguments as they already are member variable
 | 
			
		||||
@ -107,7 +107,7 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
 | 
			
		||||
            # Save
 | 
			
		||||
            if self.args.save_json:
 | 
			
		||||
                self.pred_to_json(predn, batch["im_file"][si])
 | 
			
		||||
                self.pred_to_json(predn, batch['im_file'][si])
 | 
			
		||||
            # if self.args.save_txt:
 | 
			
		||||
            #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 | 
			
		||||
 | 
			
		||||
@ -120,7 +120,7 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
 | 
			
		||||
    def print_results(self):
 | 
			
		||||
        pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys)  # print format
 | 
			
		||||
        self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
 | 
			
		||||
        self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
 | 
			
		||||
        if self.nt_per_class.sum() == 0:
 | 
			
		||||
            self.logger.warning(
 | 
			
		||||
                f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
 | 
			
		||||
@ -175,21 +175,21 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
                                 shuffle=False,
 | 
			
		||||
                                 seed=self.args.seed)[0] if self.args.v5loader else \
 | 
			
		||||
            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, names=self.data['names'],
 | 
			
		||||
                             mode="val")[0]
 | 
			
		||||
                             mode='val')[0]
 | 
			
		||||
 | 
			
		||||
    def plot_val_samples(self, batch, ni):
 | 
			
		||||
        plot_images(batch["img"],
 | 
			
		||||
                    batch["batch_idx"],
 | 
			
		||||
                    batch["cls"].squeeze(-1),
 | 
			
		||||
                    batch["bboxes"],
 | 
			
		||||
                    paths=batch["im_file"],
 | 
			
		||||
                    fname=self.save_dir / f"val_batch{ni}_labels.jpg",
 | 
			
		||||
        plot_images(batch['img'],
 | 
			
		||||
                    batch['batch_idx'],
 | 
			
		||||
                    batch['cls'].squeeze(-1),
 | 
			
		||||
                    batch['bboxes'],
 | 
			
		||||
                    paths=batch['im_file'],
 | 
			
		||||
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
 | 
			
		||||
                    names=self.names)
 | 
			
		||||
 | 
			
		||||
    def plot_predictions(self, batch, preds, ni):
 | 
			
		||||
        plot_images(batch["img"],
 | 
			
		||||
        plot_images(batch['img'],
 | 
			
		||||
                    *output_to_target(preds, max_det=15),
 | 
			
		||||
                    paths=batch["im_file"],
 | 
			
		||||
                    paths=batch['im_file'],
 | 
			
		||||
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
 | 
			
		||||
                    names=self.names)  # pred
 | 
			
		||||
 | 
			
		||||
@ -207,8 +207,8 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
 | 
			
		||||
    def eval_json(self, stats):
 | 
			
		||||
        if self.args.save_json and self.is_coco and len(self.jdict):
 | 
			
		||||
            anno_json = self.data['path'] / "annotations/instances_val2017.json"  # annotations
 | 
			
		||||
            pred_json = self.save_dir / "predictions.json"  # predictions
 | 
			
		||||
            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
 | 
			
		||||
            pred_json = self.save_dir / 'predictions.json'  # predictions
 | 
			
		||||
            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
 | 
			
		||||
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
 | 
			
		||||
                check_requirements('pycocotools>=2.0.6')
 | 
			
		||||
@ -216,7 +216,7 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
                from pycocotools.cocoeval import COCOeval  # noqa
 | 
			
		||||
 | 
			
		||||
                for x in anno_json, pred_json:
 | 
			
		||||
                    assert x.is_file(), f"{x} file not found"
 | 
			
		||||
                    assert x.is_file(), f'{x} file not found'
 | 
			
		||||
                anno = COCO(str(anno_json))  # init annotations api
 | 
			
		||||
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
 | 
			
		||||
                eval = COCOeval(anno, pred, 'bbox')
 | 
			
		||||
@ -232,8 +232,8 @@ class DetectionValidator(BaseValidator):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def val(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n.pt"
 | 
			
		||||
    data = cfg.data or "coco128.yaml"
 | 
			
		||||
    model = cfg.model or 'yolov8n.pt'
 | 
			
		||||
    data = cfg.data or 'coco128.yaml'
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, data=data)
 | 
			
		||||
    if use_python:
 | 
			
		||||
@ -244,5 +244,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        validator(model=args['model'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    val()
 | 
			
		||||
 | 
			
		||||
@ -4,4 +4,4 @@ from .predict import SegmentationPredictor, predict
 | 
			
		||||
from .train import SegmentationTrainer, train
 | 
			
		||||
from .val import SegmentationValidator, val
 | 
			
		||||
 | 
			
		||||
__all__ = ["SegmentationPredictor", "predict", "SegmentationTrainer", "train", "SegmentationValidator", "val"]
 | 
			
		||||
__all__ = ['SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val']
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ class SegmentationPredictor(DetectionPredictor):
 | 
			
		||||
 | 
			
		||||
    def write_results(self, idx, results, batch):
 | 
			
		||||
        p, im, im0 = batch
 | 
			
		||||
        log_string = ""
 | 
			
		||||
        log_string = ''
 | 
			
		||||
        if len(im.shape) == 3:
 | 
			
		||||
            im = im[None]  # expand for batch dim
 | 
			
		||||
        self.seen += 1
 | 
			
		||||
@ -84,7 +84,7 @@ class SegmentationPredictor(DetectionPredictor):
 | 
			
		||||
 | 
			
		||||
            if self.args.save or self.args.save_crop or self.args.show:  # Add bbox to image
 | 
			
		||||
                c = int(cls)  # integer class
 | 
			
		||||
                name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
 | 
			
		||||
                name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
 | 
			
		||||
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
 | 
			
		||||
                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
 | 
			
		||||
            if self.args.save_crop:
 | 
			
		||||
@ -97,9 +97,9 @@ class SegmentationPredictor(DetectionPredictor):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n-seg.pt"
 | 
			
		||||
    source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
 | 
			
		||||
        else "https://ultralytics.com/images/bus.jpg"
 | 
			
		||||
    model = cfg.model or 'yolov8n-seg.pt'
 | 
			
		||||
    source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
 | 
			
		||||
        else 'https://ultralytics.com/images/bus.jpg'
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, source=source)
 | 
			
		||||
    if use_python:
 | 
			
		||||
@ -110,5 +110,5 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        predictor.predict_cli()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    predict()
 | 
			
		||||
 | 
			
		||||
@ -20,11 +20,11 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
 | 
			
		||||
    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
 | 
			
		||||
        if overrides is None:
 | 
			
		||||
            overrides = {}
 | 
			
		||||
        overrides["task"] = "segment"
 | 
			
		||||
        overrides['task'] = 'segment'
 | 
			
		||||
        super().__init__(cfg, overrides)
 | 
			
		||||
 | 
			
		||||
    def get_model(self, cfg=None, weights=None, verbose=True):
 | 
			
		||||
        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
 | 
			
		||||
        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
 | 
			
		||||
        if weights:
 | 
			
		||||
            model.load(weights)
 | 
			
		||||
 | 
			
		||||
@ -43,13 +43,13 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
 | 
			
		||||
        return self.compute_loss(preds, batch)
 | 
			
		||||
 | 
			
		||||
    def plot_training_samples(self, batch, ni):
 | 
			
		||||
        images = batch["img"]
 | 
			
		||||
        masks = batch["masks"]
 | 
			
		||||
        cls = batch["cls"].squeeze(-1)
 | 
			
		||||
        bboxes = batch["bboxes"]
 | 
			
		||||
        paths = batch["im_file"]
 | 
			
		||||
        batch_idx = batch["batch_idx"]
 | 
			
		||||
        plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg")
 | 
			
		||||
        images = batch['img']
 | 
			
		||||
        masks = batch['masks']
 | 
			
		||||
        cls = batch['cls'].squeeze(-1)
 | 
			
		||||
        bboxes = batch['bboxes']
 | 
			
		||||
        paths = batch['im_file']
 | 
			
		||||
        batch_idx = batch['batch_idx']
 | 
			
		||||
        plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
 | 
			
		||||
 | 
			
		||||
    def plot_metrics(self):
 | 
			
		||||
        plot_results(file=self.csv, segment=True)  # save results.png
 | 
			
		||||
@ -80,15 +80,15 @@ class SegLoss(Loss):
 | 
			
		||||
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
 | 
			
		||||
 | 
			
		||||
        # targets
 | 
			
		||||
        batch_idx = batch["batch_idx"].view(-1, 1)
 | 
			
		||||
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
 | 
			
		||||
        batch_idx = batch['batch_idx'].view(-1, 1)
 | 
			
		||||
        targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
 | 
			
		||||
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
 | 
			
		||||
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
 | 
			
		||||
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
 | 
			
		||||
 | 
			
		||||
        masks = batch["masks"].to(self.device).float()
 | 
			
		||||
        masks = batch['masks'].to(self.device).float()
 | 
			
		||||
        if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
 | 
			
		||||
            masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
 | 
			
		||||
            masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
 | 
			
		||||
 | 
			
		||||
        # pboxes
 | 
			
		||||
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
 | 
			
		||||
@ -135,13 +135,13 @@ class SegLoss(Loss):
 | 
			
		||||
    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
 | 
			
		||||
        # Mask loss for one image
 | 
			
		||||
        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
 | 
			
		||||
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
 | 
			
		||||
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
 | 
			
		||||
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n-seg.pt"
 | 
			
		||||
    data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    model = cfg.model or 'yolov8n-seg.pt'
 | 
			
		||||
    data = cfg.data or 'coco128-seg.yaml'  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    device = cfg.device if cfg.device is not None else ''
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, data=data, device=device)
 | 
			
		||||
@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        trainer.train()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    train()
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
 | 
			
		||||
    def preprocess(self, batch):
 | 
			
		||||
        batch = super().preprocess(batch)
 | 
			
		||||
        batch["masks"] = batch["masks"].to(self.device).float()
 | 
			
		||||
        batch['masks'] = batch['masks'].to(self.device).float()
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
    def init_metrics(self, model):
 | 
			
		||||
@ -37,8 +37,8 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
            self.process = ops.process_mask  # faster
 | 
			
		||||
 | 
			
		||||
    def get_desc(self):
 | 
			
		||||
        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
 | 
			
		||||
                                         "R", "mAP50", "mAP50-95)")
 | 
			
		||||
        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
 | 
			
		||||
                                         'R', 'mAP50', 'mAP50-95)')
 | 
			
		||||
 | 
			
		||||
    def postprocess(self, preds):
 | 
			
		||||
        p = ops.non_max_suppression(preds[0],
 | 
			
		||||
@ -55,11 +55,11 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
    def update_metrics(self, preds, batch):
 | 
			
		||||
        # Metrics
 | 
			
		||||
        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
 | 
			
		||||
            idx = batch["batch_idx"] == si
 | 
			
		||||
            cls = batch["cls"][idx]
 | 
			
		||||
            bbox = batch["bboxes"][idx]
 | 
			
		||||
            idx = batch['batch_idx'] == si
 | 
			
		||||
            cls = batch['cls'][idx]
 | 
			
		||||
            bbox = batch['bboxes'][idx]
 | 
			
		||||
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
 | 
			
		||||
            shape = batch["ori_shape"][si]
 | 
			
		||||
            shape = batch['ori_shape'][si]
 | 
			
		||||
            correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
 | 
			
		||||
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
 | 
			
		||||
            self.seen += 1
 | 
			
		||||
@ -74,23 +74,23 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
 | 
			
		||||
            # Masks
 | 
			
		||||
            midx = [si] if self.args.overlap_mask else idx
 | 
			
		||||
            gt_masks = batch["masks"][midx]
 | 
			
		||||
            pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:])
 | 
			
		||||
            gt_masks = batch['masks'][midx]
 | 
			
		||||
            pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
 | 
			
		||||
 | 
			
		||||
            # Predictions
 | 
			
		||||
            if self.args.single_cls:
 | 
			
		||||
                pred[:, 5] = 0
 | 
			
		||||
            predn = pred.clone()
 | 
			
		||||
            ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
 | 
			
		||||
                            ratio_pad=batch["ratio_pad"][si])  # native-space pred
 | 
			
		||||
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
 | 
			
		||||
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred
 | 
			
		||||
 | 
			
		||||
            # Evaluate
 | 
			
		||||
            if nl:
 | 
			
		||||
                height, width = batch["img"].shape[2:]
 | 
			
		||||
                height, width = batch['img'].shape[2:]
 | 
			
		||||
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
 | 
			
		||||
                    (width, height, width, height), device=self.device)  # target boxes
 | 
			
		||||
                ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape,
 | 
			
		||||
                                ratio_pad=batch["ratio_pad"][si])  # native-space labels
 | 
			
		||||
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
 | 
			
		||||
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
 | 
			
		||||
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
 | 
			
		||||
                correct_bboxes = self._process_batch(predn, labelsn)
 | 
			
		||||
                # TODO: maybe remove these `self.` arguments as they already are member variable
 | 
			
		||||
@ -112,11 +112,11 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
 | 
			
		||||
            # Save
 | 
			
		||||
            if self.args.save_json:
 | 
			
		||||
                pred_masks = ops.scale_image(batch["img"][si].shape[1:],
 | 
			
		||||
                pred_masks = ops.scale_image(batch['img'][si].shape[1:],
 | 
			
		||||
                                             pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
 | 
			
		||||
                                             shape,
 | 
			
		||||
                                             ratio_pad=batch["ratio_pad"][si])
 | 
			
		||||
                self.pred_to_json(predn, batch["im_file"][si], pred_masks)
 | 
			
		||||
                                             ratio_pad=batch['ratio_pad'][si])
 | 
			
		||||
                self.pred_to_json(predn, batch['im_file'][si], pred_masks)
 | 
			
		||||
            # if self.args.save_txt:
 | 
			
		||||
            #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 | 
			
		||||
 | 
			
		||||
@ -136,7 +136,7 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
 | 
			
		||||
                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
 | 
			
		||||
            if gt_masks.shape[1:] != pred_masks.shape[1:]:
 | 
			
		||||
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
 | 
			
		||||
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
 | 
			
		||||
                gt_masks = gt_masks.gt_(0.5)
 | 
			
		||||
            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
 | 
			
		||||
        else:  # boxes
 | 
			
		||||
@ -158,20 +158,20 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
        return torch.tensor(correct, dtype=torch.bool, device=detections.device)
 | 
			
		||||
 | 
			
		||||
    def plot_val_samples(self, batch, ni):
 | 
			
		||||
        plot_images(batch["img"],
 | 
			
		||||
                    batch["batch_idx"],
 | 
			
		||||
                    batch["cls"].squeeze(-1),
 | 
			
		||||
                    batch["bboxes"],
 | 
			
		||||
                    batch["masks"],
 | 
			
		||||
                    paths=batch["im_file"],
 | 
			
		||||
                    fname=self.save_dir / f"val_batch{ni}_labels.jpg",
 | 
			
		||||
        plot_images(batch['img'],
 | 
			
		||||
                    batch['batch_idx'],
 | 
			
		||||
                    batch['cls'].squeeze(-1),
 | 
			
		||||
                    batch['bboxes'],
 | 
			
		||||
                    batch['masks'],
 | 
			
		||||
                    paths=batch['im_file'],
 | 
			
		||||
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
 | 
			
		||||
                    names=self.names)
 | 
			
		||||
 | 
			
		||||
    def plot_predictions(self, batch, preds, ni):
 | 
			
		||||
        plot_images(batch["img"],
 | 
			
		||||
        plot_images(batch['img'],
 | 
			
		||||
                    *output_to_target(preds[0], max_det=15),
 | 
			
		||||
                    torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
 | 
			
		||||
                    paths=batch["im_file"],
 | 
			
		||||
                    paths=batch['im_file'],
 | 
			
		||||
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
 | 
			
		||||
                    names=self.names)  # pred
 | 
			
		||||
        self.plot_masks.clear()
 | 
			
		||||
@ -182,8 +182,8 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
        from pycocotools.mask import encode  # noqa
 | 
			
		||||
 | 
			
		||||
        def single_encode(x):
 | 
			
		||||
            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
 | 
			
		||||
            rle["counts"] = rle["counts"].decode("utf-8")
 | 
			
		||||
            rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
 | 
			
		||||
            rle['counts'] = rle['counts'].decode('utf-8')
 | 
			
		||||
            return rle
 | 
			
		||||
 | 
			
		||||
        stem = Path(filename).stem
 | 
			
		||||
@ -203,8 +203,8 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
 | 
			
		||||
    def eval_json(self, stats):
 | 
			
		||||
        if self.args.save_json and self.is_coco and len(self.jdict):
 | 
			
		||||
            anno_json = self.data['path'] / "annotations/instances_val2017.json"  # annotations
 | 
			
		||||
            pred_json = self.save_dir / "predictions.json"  # predictions
 | 
			
		||||
            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
 | 
			
		||||
            pred_json = self.save_dir / 'predictions.json'  # predictions
 | 
			
		||||
            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
 | 
			
		||||
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
 | 
			
		||||
                check_requirements('pycocotools>=2.0.6')
 | 
			
		||||
@ -212,7 +212,7 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
                from pycocotools.cocoeval import COCOeval  # noqa
 | 
			
		||||
 | 
			
		||||
                for x in anno_json, pred_json:
 | 
			
		||||
                    assert x.is_file(), f"{x} file not found"
 | 
			
		||||
                    assert x.is_file(), f'{x} file not found'
 | 
			
		||||
                anno = COCO(str(anno_json))  # init annotations api
 | 
			
		||||
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
 | 
			
		||||
                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
 | 
			
		||||
@ -231,8 +231,8 @@ class SegmentationValidator(DetectionValidator):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def val(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
    model = cfg.model or "yolov8n-seg.pt"
 | 
			
		||||
    data = cfg.data or "coco128-seg.yaml"
 | 
			
		||||
    model = cfg.model or 'yolov8n-seg.pt'
 | 
			
		||||
    data = cfg.data or 'coco128-seg.yaml'
 | 
			
		||||
 | 
			
		||||
    args = dict(model=model, data=data)
 | 
			
		||||
    if use_python:
 | 
			
		||||
@ -243,5 +243,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
 | 
			
		||||
        validator(model=args['model'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    val()
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user