ultralytics 8.0.50
AMP check and YOLOv5u YAMLs (#1263)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Troy <wudashuo@vip.qq.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com>
@@ -574,7 +574,7 @@ class Exporter:
         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
         saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
         if self.args.int8:
-            f = saved_model / (self.file.stem + 'yolov8n_integer_quant.tflite')  # fp32 in/out
+            f = saved_model / (self.file.stem + '_integer_quant.tflite')  # fp32 in/out
         elif self.args.half:
             f = saved_model / (self.file.stem + '_float16.tflite')
         else:
@@ -863,18 +863,6 @@ def export(cfg=DEFAULT_CFG):
     cfg.model = cfg.model or 'yolov8n.yaml'
     cfg.format = cfg.format or 'torchscript'

-    # exporter = Exporter(cfg)
-    #
-    # model = None
-    # if isinstance(cfg.model, (str, Path)):
-    #     if Path(cfg.model).suffix == '.yaml':
-    #         model = DetectionModel(cfg.model)
-    #     elif Path(cfg.model).suffix == '.pt':
-    #         model = attempt_load_weights(cfg.model, fuse=True)
-    #     else:
-    #         TypeError(f'Unsupported model type {cfg.model}')
-    # exporter(model=model)
-
     from ultralytics import YOLO
     model = YOLO(cfg.model)
     model.export(**vars(cfg))
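For context, a minimal usage sketch of the simplified export path above, using only the public YOLO API visible in the diff ('yolov8n.pt' and the TorchScript format are illustrative choices, not part of the commit):

# Export a model through the same call the export() entrypoint now delegates to
from ultralytics import YOLO

model = YOLO('yolov8n.pt')          # any .pt weights or .yaml config, as with cfg.model
model.export(format='torchscript')  # writes e.g. yolov8n.torchscript next to the weights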
@@ -203,6 +203,8 @@ class YOLO:
         if source is None:
             source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
             LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
+        is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
+                 ('predict' in sys.argv or 'mode=predict' in sys.argv)

         overrides = self.overrides.copy()
         overrides['conf'] = 0.25
@@ -213,10 +215,9 @@ class YOLO:
         if not self.predictor:
             self.task = overrides.get('task') or self.task
             self.predictor = TASK_MAP[self.task][3](overrides=overrides)
-            self.predictor.setup_model(model=self.model)
+            self.predictor.setup_model(model=self.model, verbose=is_cli)
         else:  # only update args if predictor is already setup
             self.predictor.args = get_cfg(self.predictor.args, overrides)
-        is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

     def track(self, source=None, stream=False, **kwargs):
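For context, a minimal sketch of the Python-side predict call touched above (the source URL and conf value mirror the defaults visible in the diff; 'yolov8n.pt' is an illustrative weight file):

# Programmatic predict; verbose backend setup logging is now reserved for CLI runs
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
results = model.predict(source='https://ultralytics.com/images/bus.jpg', conf=0.25)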
@@ -183,6 +183,8 @@ class BasePredictor:
                     'preprocess': self.dt[0].dt * 1E3 / n,
                     'inference': self.dt[1].dt * 1E3 / n,
                     'postprocess': self.dt[2].dt * 1E3 / n}
+                if self.source_type.tensor:  # skip write, show and plot operations if input is raw tensor
+                    continue
                 p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
                     else (path, im0s.copy())
                 p = Path(p)
@@ -218,11 +220,16 @@ class BasePredictor:

         self.run_callbacks('on_predict_end')

-    def setup_model(self, model):
-        device = select_device(self.args.device)
+    def setup_model(self, model, verbose=True):
+        device = select_device(self.args.device, verbose=verbose)
         model = model or self.args.model
         self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
-        self.model = AutoBackend(model, device=device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
+        self.model = AutoBackend(model,
+                                 device=device,
+                                 dnn=self.args.dnn,
+                                 data=self.args.data,
+                                 fp16=self.args.half,
+                                 verbose=verbose)
         self.device = device
         self.model.eval()

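For context, a rough sketch of loading weights through AutoBackend with the quieter setup; the import paths and extra defaults are assumptions for this release, only the device/fp16/verbose arguments appear in the diff itself:

# Load a model via AutoBackend without console output (verbose=False)
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.utils.torch_utils import select_device

device = select_device('', verbose=False)  # auto-select GPU/CPU quietly
backend = AutoBackend('yolov8n.pt', device=device, fp16=False, verbose=False)
backend.eval()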
@@ -25,8 +25,8 @@ from tqdm import tqdm
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
-from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
-                                    colorstr, emojis, yaml_save)
+from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
+                                    callbacks, colorstr, emojis, yaml_save)
 from ultralytics.yolo.utils.autobatch import check_train_batch_size
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
@@ -111,8 +111,6 @@ class BaseTrainer:
         print_args(vars(self.args))

         # Device
-        self.amp = self.device.type != 'cpu'
-        self.scaler = amp.GradScaler(enabled=self.amp)
         if self.device.type == 'cpu':
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

@@ -126,7 +124,7 @@ class BaseTrainer:
                 if 'yaml_file' in self.data:
                     self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
         except Exception as e:
-            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
+            raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e

         self.trainset, self.testset = self.get_dataset(self.data)
         self.ema = None
@@ -204,6 +202,8 @@ class BaseTrainer:
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()
+        self.amp = check_amp(self.model)
+        self.scaler = amp.GradScaler(enabled=self.amp)
         if world_size > 1:
             self.model = DDP(self.model, device_ids=[rank])
         # Check imgsz
@@ -597,3 +597,31 @@ class BaseTrainer:
         LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                     f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
         return optimizer
+
+
+def check_amp(model):
+    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
+    device = next(model.parameters()).device  # get model device
+    if device.type in ('cpu', 'mps'):
+        return False  # AMP only used on CUDA devices
+
+    def amp_allclose(m, im):
+        # All close FP32 vs AMP results
+        a = m(im, device=device, verbose=False)[0].boxes.boxes  # FP32 inference
+        with torch.cuda.amp.autocast(True):
+            b = m(im, device=device, verbose=False)[0].boxes.boxes  # AMP inference
+        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1)  # close to 10% absolute tolerance
+
+    f = ROOT / 'assets/bus.jpg'  # image to check
+    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
+    prefix = colorstr('AMP: ')
+    try:
+        from ultralytics import YOLO
+        LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
+        assert amp_allclose(YOLO('yolov8n.pt'), im)
+        LOGGER.info(f'{prefix}checks passed ✅')
+        return True
+    except AssertionError:
+        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
+                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
+        return False
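For context, a minimal sketch of invoking the new AMP check standalone; this assumes the file shown above is ultralytics/yolo/engine/trainer.py, so check_amp is importable from there, and it only returns True on a CUDA device:

# Run the AMP sanity check outside the trainer
from ultralytics import YOLO
from ultralytics.yolo.engine.trainer import check_amp

model = YOLO('yolov8n.pt').model        # underlying nn.Module, same object the trainer passes as self.model
print('AMP usable:', check_amp(model))  # True if FP32 and AMP detections agree, False otherwise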