General `ultralytics==8.0.6` updates (#351)

Co-authored-by: Dzmitry Plashchynski <plashchynski@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 70427579b8
commit f8e32c4c13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -84,22 +84,22 @@ jobs:
- name: Test detection - name: Test detection
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=1 imgsz=32 yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32
yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=32 yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32
yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript
- name: Test segmentation - name: Test segmentation
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=segment mode=train model=yolov8n-seg.yaml data=coco8-seg.yaml epochs=1 imgsz=32 yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32
yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco8-seg.yaml imgsz=32 yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32
yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript
- name: Test classification - name: Test classification
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=classify mode=train model=yolov8n-cls.yaml data=mnist160 epochs=1 imgsz=32 yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32 yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
- name: Pytest tests - name: Pytest tests

@ -52,7 +52,7 @@ ENV OMP_NUM_THREADS=1
# t=ultralytics/ultralytics:latest tnew=ultralytics/ultralytics:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew # t=ultralytics/ultralytics:latest tnew=ultralytics/ultralytics:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew
# Clean up # Clean up
# docker system prune -a --volumes # sudo docker system prune -a --volumes
# Update Ubuntu drivers # Update Ubuntu drivers
# https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/ # https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/

@ -1,13 +1,16 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
from pathlib import Path
from ultralytics.yolo.configs import get_config from ultralytics.yolo.configs import get_config
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, SETTINGS
from ultralytics.yolo.v8 import classify, detect, segment from ultralytics.yolo.v8 import classify, detect, segment
CFG_DET = 'yolov8n.yaml' CFG_DET = 'yolov8n.yaml'
CFG_SEG = 'yolov8n-seg.yaml' CFG_SEG = 'yolov8n-seg.yaml'
CFG_CLS = 'squeezenet1_0' CFG_CLS = 'squeezenet1_0'
CFG = get_config(DEFAULT_CONFIG) CFG = get_config(DEFAULT_CONFIG)
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
SOURCE = ROOT / "assets" SOURCE = ROOT / "assets"
@ -18,15 +21,14 @@ def test_detect():
# Trainer # Trainer
trainer = detect.DetectionTrainer(overrides=overrides) trainer = detect.DetectionTrainer(overrides=overrides)
trainer.train() trainer.train()
trained_model = trainer.best
# Validator # Validator
val = detect.DetectionValidator(args=CFG) val = detect.DetectionValidator(args=CFG)
val(model=trained_model) val(model=trainer.best) # validate best.pt
# Predictor # Predictor
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
result = pred(source=SOURCE, model="yolov8n.pt", return_outputs=True) result = pred(source=SOURCE, model=f"{MODEL}.pt", return_outputs=True)
assert len(list(result)), "predictor test failed" assert len(list(result)), "predictor test failed"
overrides["resume"] = trainer.last overrides["resume"] = trainer.last
@ -49,15 +51,14 @@ def test_segment():
# trainer # trainer
trainer = segment.SegmentationTrainer(overrides=overrides) trainer = segment.SegmentationTrainer(overrides=overrides)
trainer.train() trainer.train()
trained_model = trainer.best
# Validator # Validator
val = segment.SegmentationValidator(args=CFG) val = segment.SegmentationValidator(args=CFG)
val(model=trained_model) val(model=trainer.best) # validate best.pt
# Predictor # Predictor
pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]}) pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
result = pred(source=SOURCE, model="yolov8n-seg.pt", return_outputs=True) result = pred(source=SOURCE, model=f"{MODEL}-seg.pt", return_outputs=True)
assert len(list(result)) == 2, "predictor test failed" assert len(list(result)) == 2, "predictor test failed"
# Test resume # Test resume
@ -82,13 +83,12 @@ def test_classify():
# Trainer # Trainer
trainer = classify.ClassificationTrainer(overrides=overrides) trainer = classify.ClassificationTrainer(overrides=overrides)
trainer.train() trainer.train()
trained_model = trainer.best
# Validator # Validator
val = classify.ClassificationValidator(args=CFG) val = classify.ClassificationValidator(args=CFG)
val(model=trained_model) val(model=trainer.best)
# Predictor # Predictor
pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]}) pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
result = pred(source=SOURCE, model=trained_model, return_outputs=True) result = pred(source=SOURCE, model=trainer.best, return_outputs=True)
assert len(list(result)) == 2, "predictor test failed" assert len(list(result)) == 2, "predictor test failed"

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import signal
import sys
from pathlib import Path from pathlib import Path
from time import sleep from time import sleep
@ -15,19 +13,21 @@ AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__versio
session = None session = None
# Causing problems in tests (non-authenticated)
def signal_handler(signum, frame): # import signal
""" Confirm exit """ # import sys
global hub_logger # def signal_handler(signum, frame):
LOGGER.info(f'Signal received. {signum} {frame}') # """ Confirm exit """
if isinstance(session, HubTrainingSession): # global hub_logger
hub_logger.alive = False # LOGGER.info(f'Signal received. {signum} {frame}')
del hub_logger # if isinstance(session, HubTrainingSession):
sys.exit(signum) # hub_logger.alive = False
# del hub_logger
# sys.exit(signum)
signal.signal(signal.SIGTERM, signal_handler) #
signal.signal(signal.SIGINT, signal_handler) #
# signal.signal(signal.SIGTERM, signal_handler)
# signal.signal(signal.SIGINT, signal_handler)
class HubTrainingSession: class HubTrainingSession:

@ -8,13 +8,13 @@ from omegaconf import DictConfig, OmegaConf
from ultralytics.yolo.configs.hydra_patch import check_config_mismatch from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = None): def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None):
""" """
Load and merge configuration data from a file or dictionary. Load and merge configuration data from a file or dictionary.
Args: Args:
config (Union[str, DictConfig]): Configuration data in the form of a file name or a DictConfig object. config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object.
overrides (Union[str, Dict], optional): Overrides in the form of a file name or a dictionary. Default is None. overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
Returns: Returns:
OmegaConf.Namespace: Training arguments namespace. OmegaConf.Namespace: Training arguments namespace.

@ -14,12 +14,11 @@ import numpy as np
import torch import torch
from PIL import ExifTags, Image, ImageOps from PIL import ExifTags, Image, ImageOps
from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, yaml_load from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import unzip_file from ultralytics.yolo.utils.files import unzip_file
from ultralytics.yolo.utils.ops import segments2boxes
from ..utils.ops import segments2boxes
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data" HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
@ -173,12 +172,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
areas = [] areas = []
ms = [] ms = []
for si in range(len(segments)): for si in range(len(segments)):
mask = polygon2mask( mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
imgsz,
[segments[si].reshape(-1)],
downsample_ratio=downsample_ratio,
color=1,
)
ms.append(mask) ms.append(mask)
areas.append(mask.sum()) areas.append(mask.sum())
areas = np.asarray(areas) areas = np.asarray(areas)
@ -194,13 +188,14 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
def check_dataset_yaml(data, autodownload=True): def check_dataset_yaml(data, autodownload=True):
# Download, check and/or unzip dataset if not found locally # Download, check and/or unzip dataset if not found locally
data = check_file(data) data = check_file(data)
DATASETS_DIR = (Path.cwd() / "../datasets").resolve() # TODO: handle global dataset dir
# Download (optional) # Download (optional)
extract_dir = '' extract_dir = ''
if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1) download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False extract_dir, autodownload = data.parent, False
# Read yaml (optional) # Read yaml (optional)
if isinstance(data, (str, Path)): if isinstance(data, (str, Path)):
data = yaml_load(data, append_filename=True) # dictionary data = yaml_load(data, append_filename=True) # dictionary
@ -215,7 +210,7 @@ def check_dataset_yaml(data, autodownload=True):
# Resolve paths # Resolve paths
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.' path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
if not path.is_absolute(): if not path.is_absolute():
path = (Path.cwd() / path).resolve() path = (DATASETS_DIR / path).resolve()
data['path'] = path # download scripts data['path'] = path # download scripts
for k in 'train', 'val', 'test': for k in 'train', 'val', 'test':
if data.get(k): # prepend path if data.get(k): # prepend path
@ -253,6 +248,7 @@ def check_dataset_yaml(data, autodownload=True):
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}" s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
LOGGER.info(f"Dataset download {s}") LOGGER.info(f"Dataset download {s}")
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
return data # dictionary return data # dictionary
@ -274,12 +270,12 @@ def check_dataset(dataset: str):
'nc': Number of classes in the dataset 'nc': Number of classes in the dataset
'names': List of class names in the dataset 'names': List of class names in the dataset
""" """
data_dir = (Path.cwd() / "datasets" / dataset).resolve() data_dir = (DATASETS_DIR / dataset).resolve()
if not data_dir.is_dir(): if not data_dir.is_dir():
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
t = time.time() t = time.time()
if dataset == 'imagenet': if dataset == 'imagenet':
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True)
else: else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
download(url, dir=data_dir.parent) download(url, dir=data_dir.parent)

@ -240,7 +240,7 @@ class BasePredictor:
if isinstance(self.vid_writer[idx], cv2.VideoWriter): if isinstance(self.vid_writer[idx], cv2.VideoWriter):
self.vid_writer[idx].release() # release previous video writer self.vid_writer[idx].release() # release previous video writer
if vid_cap: # video if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS) fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream else: # stream

@ -506,9 +506,11 @@ class BaseTrainer:
def check_resume(self): def check_resume(self):
resume = self.args.resume resume = self.args.resume
if resume: if resume:
last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run()) last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run())
args_yaml = last.parent.parent / 'args.yaml' # train options yaml args_yaml = last.parent.parent / 'args.yaml' # train options yaml
if args_yaml.is_file(): assert args_yaml.is_file(), \
FileNotFoundError('Resume checkpoint f{last} not found. '
'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt')
args = get_config(args_yaml) # replace args = get_config(args_yaml) # replace
args.model, resume = str(last), True # reinstate args.model, resume = str(last), True # reinstate
self.args = args self.args = args

@ -187,7 +187,7 @@ def get_git_root_dir():
""" """
try: try:
output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
return Path(output.stdout.strip().decode('utf-8')).parent # parent/.git return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # parent/.git
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
return None return None
@ -348,16 +348,18 @@ def yaml_load(file='data.yaml', append_filename=False):
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f) return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.0'):
""" """
Loads a global settings YAML file or creates one with default values if it does not exist. Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
Args: Args:
file (Path): Path to the settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR. file (Path): Path to the Ultralytics settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR.
version (str): Settings version. If min settings version not met, new default settings will be saved.
Returns: Returns:
dict: Dictionary of settings key-value pairs. dict: Dictionary of settings key-value pairs.
""" """
from ultralytics.yolo.utils.checks import check_version
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
root = get_git_root_dir() or Path('') # not is_pip_package() root = get_git_root_dir() or Path('') # not is_pip_package()
@ -366,7 +368,8 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
'weights_dir': str(root / 'weights'), # default weights directory. 'weights_dir': str(root / 'weights'), # default weights directory.
'runs_dir': str(root / 'runs'), # default runs directory. 'runs_dir': str(root / 'runs'), # default runs directory.
'sync': True, # sync analytics to help with YOLO development 'sync': True, # sync analytics to help with YOLO development
'uuid': uuid.getnode()} # device UUID to align analytics 'uuid': uuid.getnode(), # device UUID to align analytics
'settings_version': version} # Ultralytics settings version
with torch_distributed_zero_first(RANK): with torch_distributed_zero_first(RANK):
if not file.exists(): if not file.exists():
@ -375,12 +378,14 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
settings = yaml_load(file) settings = yaml_load(file)
# Check that settings keys and types match defaults # Check that settings keys and types match defaults
correct = settings.keys() == defaults.keys() and \ correct = settings.keys() == defaults.keys() \
all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
and check_version(settings['settings_version'], version)
if not correct: if not correct:
LOGGER.warning('WARNING ⚠️ Different global settings detected, resetting to defaults. ' LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. '
'This may be due to an ultralytics package update. ' '\nThis is normal and may be due to a recent ultralytics package update, '
f'View and update your global settings directly in {file}') 'but may have overwritten previous settings. '
f"\nYou may view and update settings directly in '{file}'")
settings = defaults # merge **defaults with **settings (prefer **settings) settings = defaults # merge **defaults with **settings (prefer **settings)
yaml_save(file, settings) # save updated defaults yaml_save(file, settings) # save updated defaults

@ -3,8 +3,6 @@
import json import json
from time import time from time import time
import torch
from ultralytics.hub.utils import PREFIX, sync_analytics from ultralytics.hub.utils import PREFIX, sync_analytics
from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils import LOGGER

@ -252,7 +252,7 @@ class ConfusionMatrix:
vmin=0.0, vmin=0.0,
xticklabels=ticklabels, xticklabels=ticklabels,
yticklabels=ticklabels).set_facecolor((1, 1, 1)) yticklabels=ticklabels).set_facecolor((1, 1, 1))
ax.set_ylabel('True') ax.set_xlabel('True')
ax.set_ylabel('Predicted') ax.set_ylabel('Predicted')
ax.set_title('Confusion Matrix') ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)

@ -113,11 +113,10 @@ class ClassificationTrainer(BaseTrainer):
""" """
# Not needed for classification but necessary for segmentation & detection # Not needed for classification but necessary for segmentation & detection
keys = [f"{prefix}/{x}" for x in self.loss_names] keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None: if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)] loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items)) return dict(zip(keys, loss_items))
else:
return keys
def resume_training(self, ckpt): def resume_training(self, ckpt):
pass pass

@ -48,14 +48,14 @@ class DetectionTrainer(BaseTrainer):
return batch return batch
def set_model_attributes(self): def set_model_attributes(self):
nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps) # nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
self.args.box *= 3 / nl # scale to layers # self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data["nc"] # attach number of classes to model self.model.nc = self.data["nc"] # attach number of classes to model
self.model.names = self.data["names"] # attach class names to model
self.model.args = self.args # attach hyperparameters to model self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose) model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)

@ -6,8 +6,7 @@ import torch
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import colors, save_one_box from ultralytics.yolo.utils.plotting import colors, save_one_box
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
from ..detect.predict import DetectionPredictor
class SegmentationPredictor(DetectionPredictor): class SegmentationPredictor(DetectionPredictor):

@ -13,14 +13,15 @@ from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import make_anchors from ultralytics.yolo.utils.tal import make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel from ultralytics.yolo.utils.torch_utils import de_parallel
from ultralytics.yolo.v8.detect.train import Loss
from ..detect.train import Loss
# BaseTrainer python usage # BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides={}): def __init__(self, config=DEFAULT_CONFIG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "segment" overrides["task"] = "segment"
super().__init__(config, overrides) super().__init__(config, overrides)

@ -13,8 +13,7 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images from ultralytics.yolo.utils.plotting import output_to_target, plot_images
from ultralytics.yolo.v8.detect import DetectionValidator
from ..detect import DetectionValidator
class SegmentationValidator(DetectionValidator): class SegmentationValidator(DetectionValidator):

Loading…
Cancel
Save