Update .pre-commit-config.yaml (#1026)
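The Python hunks below are the mechanical fallout of a quote-style hook: double-quoted string literals are rewritten to single quotes across the repository. The .pre-commit-config.yaml change itself is not included in this extract, so the entry below is only a hedged sketch of the kind of hook (pre-commit-hooks' double-quote-string-fixer, with an assumed rev) that produces these rewrites, not the commit's actual config diff:

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0  # assumed revision, not taken from this commit
    hooks:
      - id: double-quote-string-fixer  # rewrites "..." literals to '...' where safe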
@@ -144,7 +144,7 @@ class Exporter:

     @smart_inference_mode()
     def __call__(self, model=None):
-        self.run_callbacks("on_export_start")
+        self.run_callbacks('on_export_start')
         t = time.time()
         format = self.args.format.lower()  # to lowercase
         if format in {'tensorrt', 'trt'}:  # engine aliases
@@ -207,7 +207,7 @@ class Exporter:
         self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
         self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
         self.metadata = {
-            'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}",
+            'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
             'author': 'Ultralytics',
             'license': 'GPL-3.0 https://ultralytics.com/license',
             'version': __version__,
@@ -215,7 +215,7 @@ class Exporter:
             'names': model.names}  # model metadata

         LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
-                    f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
+                    f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')

         # Exports
         f = [''] * len(fmts)  # exported filenames
@@ -259,15 +259,15 @@ class Exporter:
             s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
                                   f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
             imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
-            data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
+            data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
             LOGGER.info(
                 f'\nExport complete ({time.time() - t:.1f}s)'
                 f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
-                f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
-                f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
-                f"\nVisualize: https://netron.app")
+                f'\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
+                f'\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
+                f'\nVisualize: https://netron.app')

-        self.run_callbacks("on_export_end")
+        self.run_callbacks('on_export_end')
         return f  # return list of exported files/dirs

     @try_export
@@ -277,7 +277,7 @@ class Exporter:
         f = self.file.with_suffix('.torchscript')

         ts = torch.jit.trace(self.model, self.im, strict=False)
-        d = {"shape": self.im.shape, "stride": int(max(self.model.stride)), "names": self.model.names}
+        d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
         extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
         if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
             LOGGER.info(f'{prefix} optimizing for mobile...')
@@ -354,7 +354,7 @@ class Exporter:

         ov_model = mo.convert_model(f_onnx,
                                     model_name=self.pretty_name,
-                                    framework="onnx",
+                                    framework='onnx',
                                     compress_to_fp16=self.args.half)  # export
         ov.serialize(ov_model, f_ov)  # save
         yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
@@ -471,7 +471,7 @@ class Exporter:
         if self.args.dynamic:
             shape = self.im.shape
             if shape[0] <= 1:
-                LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
+                LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
             profile = builder.create_optimization_profile()
             for inp in inputs:
                 profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
@@ -509,8 +509,8 @@ class Exporter:
         except ImportError:
             check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
             import tensorflow as tf  # noqa
-        check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
-                           cmds="--extra-index-url https://pypi.ngc.nvidia.com")
+        check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
+                           cmds='--extra-index-url https://pypi.ngc.nvidia.com')

         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
         f = str(self.file).replace(self.file.suffix, '_saved_model')
@@ -632,7 +632,7 @@ class Exporter:
             converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

         tflite_model = converter.convert()
-        open(f, "wb").write(tflite_model)
+        open(f, 'wb').write(tflite_model)
         return f, None

     @try_export
@@ -656,7 +656,7 @@ class Exporter:
         LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
         f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model

-        cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
+        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
         subprocess.run(cmd.split(), check=True)
         self._add_tflite_metadata(f)
         return f, None
@@ -707,8 +707,8 @@ class Exporter:

         # Creates input info.
         input_meta = _metadata_fb.TensorMetadataT()
-        input_meta.name = "image"
-        input_meta.description = "Input image to be detected."
+        input_meta.name = 'image'
+        input_meta.description = 'Input image to be detected.'
         input_meta.content = _metadata_fb.ContentT()
         input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
         input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
@@ -716,8 +716,8 @@ class Exporter:

         # Creates output info.
         output_meta = _metadata_fb.TensorMetadataT()
-        output_meta.name = "output"
-        output_meta.description = "Coordinates of detected objects, class labels, and confidence score."
+        output_meta.name = 'output'
+        output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'

         # Label file
         tmp_file = Path('/tmp/meta.txt')
@@ -868,8 +868,8 @@ class Exporter:


 def export(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n.yaml"
-    cfg.format = cfg.format or "torchscript"
+    cfg.model = cfg.model or 'yolov8n.yaml'
+    cfg.format = cfg.format or 'torchscript'

     # exporter = Exporter(cfg)
     #
@@ -888,7 +888,7 @@ def export(cfg=DEFAULT_CFG):
     model.export(**vars(cfg))


-if __name__ == "__main__":
+if __name__ == '__main__':
     """
     CLI:
         yolo mode=export model=yolov8n.yaml format=onnx
@@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode

 # Map head to model, trainer, validator, and predictor classes
 MODEL_MAP = {
-    "classify": [
+    'classify': [
         ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
         'yolo.TYPE.classify.ClassificationPredictor'],
-    "detect": [
+    'detect': [
         DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
         'yolo.TYPE.detect.DetectionPredictor'],
-    "segment": [
+    'segment': [
         SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
         'yolo.TYPE.segment.SegmentationPredictor']}

@@ -34,7 +34,7 @@ class YOLO:
     A python interface which emulates a model-like behaviour by wrapping trainers.
     """

-    def __init__(self, model='yolov8n.pt', type="v8") -> None:
+    def __init__(self, model='yolov8n.pt', type='v8') -> None:
         """
         Initializes the YOLO object.

@@ -94,7 +94,7 @@ class YOLO:
         suffix = Path(weights).suffix
         if suffix == '.pt':
             self.model, self.ckpt = attempt_load_one_weight(weights)
-            self.task = self.model.args["task"]
+            self.task = self.model.args['task']
             self.overrides = self.model.args
             self._reset_ckpt_args(self.overrides)
         else:
@@ -111,7 +111,7 @@ class YOLO:
         """
         if not isinstance(self.model, nn.Module):
             raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
-                            f"PyTorch models can be used to train, val, predict and export, i.e. "
+                            f'PyTorch models can be used to train, val, predict and export, i.e. '
                             f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
                             f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")

@@ -155,11 +155,11 @@ class YOLO:
             (List[ultralytics.yolo.engine.results.Results]): The prediction results.
         """
         overrides = self.overrides.copy()
-        overrides["conf"] = 0.25
+        overrides['conf'] = 0.25
         overrides.update(kwargs)
-        overrides["mode"] = kwargs.get("mode", "predict")
-        assert overrides["mode"] in ['track', 'predict']
-        overrides["save"] = kwargs.get("save", False)  # not save files by default
+        overrides['mode'] = kwargs.get('mode', 'predict')
+        assert overrides['mode'] in ['track', 'predict']
+        overrides['save'] = kwargs.get('save', False)  # not save files by default
         if not self.predictor:
             self.predictor = self.PredictorClass(overrides=overrides)
             self.predictor.setup_model(model=self.model)
@@ -173,7 +173,7 @@ class YOLO:
         from ultralytics.tracker.track import register_tracker
         register_tracker(self)
         # bytetrack-based method needs low confidence predictions as input
-        conf = kwargs.get("conf") or 0.1
+        conf = kwargs.get('conf') or 0.1
         kwargs['conf'] = conf
         kwargs['mode'] = 'track'
         return self.predict(source=source, stream=stream, **kwargs)
@@ -188,9 +188,9 @@ class YOLO:
             **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
         """
         overrides = self.overrides.copy()
-        overrides["rect"] = True  # rect batches as default
+        overrides['rect'] = True  # rect batches as default
         overrides.update(kwargs)
-        overrides["mode"] = "val"
+        overrides['mode'] = 'val'
         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
         args.data = data or args.data
         args.task = self.task
@@ -234,18 +234,18 @@ class YOLO:
         self._check_is_pytorch_model()
         overrides = self.overrides.copy()
         overrides.update(kwargs)
-        if kwargs.get("cfg"):
+        if kwargs.get('cfg'):
             LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
-            overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
-        overrides["task"] = self.task
-        overrides["mode"] = "train"
-        if not overrides.get("data"):
+            overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
+        overrides['task'] = self.task
+        overrides['mode'] = 'train'
+        if not overrides.get('data'):
             raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
-        if overrides.get("resume"):
-            overrides["resume"] = self.ckpt_path
+        if overrides.get('resume'):
+            overrides['resume'] = self.ckpt_path

         self.trainer = self.TrainerClass(overrides=overrides)
-        if not overrides.get("resume"):  # manually set model only if not resuming
+        if not overrides.get('resume'):  # manually set model only if not resuming
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model
         self.trainer.train()
@@ -267,9 +267,9 @@ class YOLO:

     def _assign_ops_from_task(self):
         model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
-        trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
-        validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
-        predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
+        trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
+        validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
+        predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
         return model_class, trainer_class, validator_class, predictor_class

     @property
@@ -292,7 +292,7 @@ class YOLO:
         Returns metrics if computed
         """
         if not self.metrics_data:
-            LOGGER.info("No metrics data found! Run training or validation operation first.")
+            LOGGER.info('No metrics data found! Run training or validation operation first.')

         return self.metrics_data

@@ -72,7 +72,7 @@ class BasePredictor:
         """
         self.args = get_cfg(cfg, overrides)
         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
-        name = self.args.name or f"{self.args.mode}"
+        name = self.args.name or f'{self.args.mode}'
         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
         if self.args.conf is None:
             self.args.conf = 0.25  # default conf=0.25
@@ -97,10 +97,10 @@ class BasePredictor:
         pass

     def get_annotator(self, img):
-        raise NotImplementedError("get_annotator function needs to be implemented")
+        raise NotImplementedError('get_annotator function needs to be implemented')

     def write_results(self, results, batch, print_string):
-        raise NotImplementedError("print_results function needs to be implemented")
+        raise NotImplementedError('print_results function needs to be implemented')

     def postprocess(self, preds, img, orig_img):
         return preds
@@ -135,7 +135,7 @@ class BasePredictor:

     def stream_inference(self, source=None, model=None):
         if self.args.verbose:
-            LOGGER.info("")
+            LOGGER.info('')

         # setup model
         if not self.model:
@@ -152,9 +152,9 @@ class BasePredictor:
             self.done_warmup = True

         self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
-        self.run_callbacks("on_predict_start")
+        self.run_callbacks('on_predict_start')
         for batch in self.dataset:
-            self.run_callbacks("on_predict_batch_start")
+            self.run_callbacks('on_predict_batch_start')
             self.batch = batch
             path, im, im0s, vid_cap, s = batch
             visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
@@ -170,7 +170,7 @@ class BasePredictor:
             # postprocess
             with self.dt[2]:
                 self.results = self.postprocess(preds, im, im0s)
-            self.run_callbacks("on_predict_postprocess_end")
+            self.run_callbacks('on_predict_postprocess_end')

             # visualize, save, write results
             for i in range(len(im)):
@@ -186,7 +186,7 @@ class BasePredictor:

                 if self.args.save:
                     self.save_preds(vid_cap, i, str(self.save_dir / p.name))
-            self.run_callbacks("on_predict_batch_end")
+            self.run_callbacks('on_predict_batch_end')
             yield from self.results

             # Print time (inference-only)
@@ -207,7 +207,7 @@ class BasePredictor:
             s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

-        self.run_callbacks("on_predict_end")
+        self.run_callbacks('on_predict_end')

     def setup_model(self, model):
         device = select_device(self.args.device)
@@ -36,7 +36,7 @@ class Results:
         self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
         self.probs = probs if probs is not None else None
         self.names = names
-        self.comp = ["boxes", "masks", "probs"]
+        self.comp = ['boxes', 'masks', 'probs']

     def pandas(self):
         pass
@@ -97,7 +97,7 @@ class Results:
         return len(getattr(self, item))

     def __str__(self):
-        str_out = ""
+        str_out = ''
         for item in self.comp:
             if getattr(self, item) is None:
                 continue
@@ -105,7 +105,7 @@ class Results:
         return str_out

     def __repr__(self):
-        str_out = ""
+        str_out = ''
         for item in self.comp:
             if getattr(self, item) is None:
                 continue
@@ -187,7 +187,7 @@ class Boxes:
         if boxes.ndim == 1:
             boxes = boxes[None, :]
         n = boxes.shape[-1]
-        assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}"  # xyxy, (track_id), conf, cls
+        assert n in {6, 7}, f'expected `n` in [6, 7], but got {n}'  # xyxy, (track_id), conf, cls
         # TODO
         self.is_track = n == 7
         self.boxes = boxes
@@ -268,8 +268,8 @@ class Boxes:
         return self.boxes.__str__()

     def __repr__(self):
-        return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.boxes)}\n" +
-                f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}")
+        return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.boxes)}\n' +
+                f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')

     def __getitem__(self, idx):
         boxes = self.boxes[idx]
@@ -353,8 +353,8 @@ class Masks:
         return self.masks.__str__()

     def __repr__(self):
-        return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
-                f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}\n + {self.masks.__repr__()}")
+        return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
+                f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')

     def __getitem__(self, idx):
         masks = self.masks[idx]
@@ -374,19 +374,19 @@ class Masks:
            """)


-if __name__ == "__main__":
+if __name__ == '__main__':
     # test examples
     results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
     results = results.cuda()
-    print("--cuda--pass--")
+    print('--cuda--pass--')
     results = results.cpu()
-    print("--cpu--pass--")
-    results = results.to("cuda:0")
-    print("--to-cuda--pass--")
-    results = results.to("cpu")
-    print("--to-cpu--pass--")
+    print('--cpu--pass--')
+    results = results.to('cuda:0')
+    print('--to-cuda--pass--')
+    results = results.to('cpu')
+    print('--to-cpu--pass--')
     results = results.numpy()
-    print("--numpy--pass--")
+    print('--numpy--pass--')
     # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
     # box = box.cuda()
     # box = box.cpu()
@@ -90,7 +90,7 @@ class BaseTrainer:

         # Dirs
         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
-        name = self.args.name or f"{self.args.mode}"
+        name = self.args.name or f'{self.args.mode}'
         if hasattr(self.args, 'save_dir'):
             self.save_dir = Path(self.args.save_dir)
         else:
@@ -121,7 +121,7 @@ class BaseTrainer:
         try:
             if self.args.task == 'classify':
                 self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.endswith(".yaml") or self.args.task in ('detect', 'segment'):
+            elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
                 self.data = check_det_dataset(self.args.data)
                 if 'yaml_file' in self.data:
                     self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
@@ -175,7 +175,7 @@ class BaseTrainer:
             world_size = 0

         # Run subprocess if DDP training, else train normally
-        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
             cmd, file = generate_ddp_command(world_size, self)  # security vulnerability in Snyk scans
             try:
                 subprocess.run(cmd, check=True)
@@ -191,15 +191,15 @@ class BaseTrainer:
         # os.environ['MASTER_PORT'] = '9020'
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
-        self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
-        dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
+        self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
         """
         Builds dataloaders and optimizer on correct rank process.
         """
         # model
-        self.run_callbacks("on_pretrain_routine_start")
+        self.run_callbacks('on_pretrain_routine_start')
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()
@@ -234,16 +234,16 @@ class BaseTrainer:

         # dataloaders
         batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
         if rank in {0, -1}:
-            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
+            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
             self.validator = self.get_validator()
-            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
             self.ema = ModelEMA(self.model)
         self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
-        self.run_callbacks("on_pretrain_routine_end")
+        self.run_callbacks('on_pretrain_routine_end')

     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
@@ -257,24 +257,24 @@ class BaseTrainer:
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        self.run_callbacks("on_train_start")
-        self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+        self.run_callbacks('on_train_start')
+        self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                  f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                  f"Logging results to {colorstr('bold', self.save_dir)}\n"
-                 f"Starting training for {self.epochs} epochs...")
+                 f'Starting training for {self.epochs} epochs...')
         if self.args.close_mosaic:
             base_idx = (self.epochs - self.args.close_mosaic) * nb
             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
         for epoch in range(self.start_epoch, self.epochs):
             self.epoch = epoch
-            self.run_callbacks("on_train_epoch_start")
+            self.run_callbacks('on_train_epoch_start')
             self.model.train()
             if rank != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
             # Update dataloader attributes (optional)
             if epoch == (self.epochs - self.args.close_mosaic):
-                self.console.info("Closing dataloader mosaic")
+                self.console.info('Closing dataloader mosaic')
                 if hasattr(self.train_loader.dataset, 'mosaic'):
                     self.train_loader.dataset.mosaic = False
                 if hasattr(self.train_loader.dataset, 'close_mosaic'):
@@ -286,7 +286,7 @@ class BaseTrainer:
             self.tloss = None
             self.optimizer.zero_grad()
             for i, batch in pbar:
-                self.run_callbacks("on_train_batch_start")
+                self.run_callbacks('on_train_batch_start')
                 # Warmup
                 ni = i + nb * epoch
                 if ni <= nw:
@@ -302,7 +302,7 @@ class BaseTrainer:
                 # Forward
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    preds = self.model(batch["img"])
+                    preds = self.model(batch['img'])
                     self.loss, self.loss_items = self.criterion(preds, batch)
                     if rank != -1:
                         self.loss *= world_size
@@ -324,17 +324,17 @@ class BaseTrainer:
                 if rank in {-1, 0}:
                     pbar.set_description(
                         ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
-                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
+                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                     self.run_callbacks('on_batch_end')
                     if self.args.plots and ni in self.plot_idx:
                         self.plot_training_samples(batch, ni)

-                self.run_callbacks("on_train_batch_end")
+                self.run_callbacks('on_train_batch_end')

-            self.lr = {f"lr/pg{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers

             self.scheduler.step()
-            self.run_callbacks("on_train_epoch_end")
+            self.run_callbacks('on_train_epoch_end')

             if rank in {-1, 0}:

@@ -355,7 +355,7 @@ class BaseTrainer:
             tnow = time.time()
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
-            self.run_callbacks("on_fit_epoch_end")
+            self.run_callbacks('on_fit_epoch_end')

             # Early Stopping
             if RANK != -1:  # if DDP training
@@ -402,7 +402,7 @@ class BaseTrainer:
         """
         Get train, val path from data dict if it exists. Returns None if data format is not recognized.
         """
-        return data["train"], data.get("val") or data.get("test")
+        return data['train'], data.get('val') or data.get('test')

     def setup_model(self):
         """
@@ -413,9 +413,9 @@ class BaseTrainer:

         model, weights = self.model, None
         ckpt = None
-        if str(model).endswith(".pt"):
+        if str(model).endswith('.pt'):
             weights, ckpt = attempt_load_one_weight(model)
-            cfg = ckpt["model"].yaml
+            cfg = ckpt['model'].yaml
         else:
             cfg = model
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@@ -441,7 +441,7 @@ class BaseTrainer:
         Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
         """
         metrics = self.validator(self)
-        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
         return metrics, fitness
@@ -462,38 +462,38 @@ class BaseTrainer:
         raise NotImplementedError("This task trainer doesn't support loading cfg files")

     def get_validator(self):
-        raise NotImplementedError("get_validator function not implemented in trainer")
+        raise NotImplementedError('get_validator function not implemented in trainer')

-    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
         """
         Returns dataloader derived from torch.data.Dataloader.
         """
-        raise NotImplementedError("get_dataloader function not implemented in trainer")
+        raise NotImplementedError('get_dataloader function not implemented in trainer')

     def criterion(self, preds, batch):
         """
         Returns loss and individual loss items as Tensor.
         """
-        raise NotImplementedError("criterion function not implemented in trainer")
+        raise NotImplementedError('criterion function not implemented in trainer')

-    def label_loss_items(self, loss_items=None, prefix="train"):
+    def label_loss_items(self, loss_items=None, prefix='train'):
         """
         Returns a loss dict with labelled training loss items tensor
         """
         # Not needed for classification but necessary for segmentation & detection
-        return {"loss": loss_items} if loss_items is not None else ["loss"]
+        return {'loss': loss_items} if loss_items is not None else ['loss']

     def set_model_attributes(self):
         """
         To set or update model parameters before training.
         """
-        self.model.names = self.data["names"]
+        self.model.names = self.data['names']

     def build_targets(self, preds, targets):
         pass

     def progress_string(self):
-        return ""
+        return ''

     # TODO: may need to put these following functions into callback
     def plot_training_samples(self, batch, ni):
@@ -529,7 +529,7 @@ class BaseTrainer:
                 self.args = get_cfg(attempt_load_weights(last).args)
                 self.args.model, resume = str(last), True  # reinstate
             except Exception as e:
-                raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                         "i.e. 'yolo train resume model=path/to/last.pt'") from e
         self.resume = resume

@@ -557,7 +557,7 @@ class BaseTrainer:
         self.best_fitness = best_fitness
         self.start_epoch = start_epoch
         if start_epoch > (self.epochs - self.args.close_mosaic):
-            self.console.info("Closing dataloader mosaic")
+            self.console.info('Closing dataloader mosaic')
             if hasattr(self.train_loader.dataset, 'mosaic'):
                 self.train_loader.dataset.mosaic = False
             if hasattr(self.train_loader.dataset, 'close_mosaic'):
@@ -602,5 +602,5 @@ class BaseTrainer:
         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
         LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
-                    f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
+                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
         return optimizer
@@ -62,7 +62,7 @@ class BaseValidator:
         self.jdict = None

         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
-        name = self.args.name or f"{self.args.mode}"
+        name = self.args.name or f'{self.args.mode}'
         self.save_dir = save_dir or increment_path(Path(project) / name,
                                                    exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
         (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@@ -92,7 +92,7 @@ class BaseValidator:
         else:
             callbacks.add_integration_callbacks(self)
             self.run_callbacks('on_val_start')
-            assert model is not None, "Either trainer or model is needed for validation"
+            assert model is not None, 'Either trainer or model is needed for validation'
             self.device = select_device(self.args.device, self.args.batch)
             self.args.half &= self.device.type != 'cpu'
             model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
@@ -108,7 +108,7 @@ class BaseValidator:
                 self.logger.info(
                     f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

-            if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
+            if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
                 self.data = check_det_dataset(self.args.data)
             elif self.args.task == 'classify':
                 self.data = check_cls_dataset(self.args.data)
@@ -142,7 +142,7 @@ class BaseValidator:

             # inference
             with dt[1]:
-                preds = model(batch["img"])
+                preds = model(batch['img'])

             # loss
             with dt[2]:
@@ -166,14 +166,14 @@ class BaseValidator:
         self.run_callbacks('on_val_end')
         if self.training:
             model.float()
-            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
             self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
                              self.speed)
             if self.args.save_json and self.jdict:
-                with open(str(self.save_dir / "predictions.json"), 'w') as f:
-                    self.logger.info(f"Saving {f.name}...")
+                with open(str(self.save_dir / 'predictions.json'), 'w') as f:
+                    self.logger.info(f'Saving {f.name}...')
                     json.dump(self.jdict, f)  # flatten and save
                 stats = self.eval_json(stats)  # update stats
             return stats
@@ -183,7 +183,7 @@ class BaseValidator:
             callback(self)

     def get_dataloader(self, dataset_path, batch_size):
-        raise NotImplementedError("get_dataloader function not implemented for this validator")
+        raise NotImplementedError('get_dataloader function not implemented for this validator')

     def preprocess(self, batch):
         return batch