Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher
2023-02-17 22:26:40 +01:00
committed by GitHub
parent 9047d737f4
commit edd3ff1669
76 changed files with 928 additions and 935 deletions

View File

@ -144,7 +144,7 @@ class Exporter:
@smart_inference_mode()
def __call__(self, model=None):
self.run_callbacks("on_export_start")
self.run_callbacks('on_export_start')
t = time.time()
format = self.args.format.lower() # to lowercase
if format in {'tensorrt', 'trt'}: # engine aliases
@ -207,7 +207,7 @@ class Exporter:
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
self.metadata = {
'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}",
'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
'author': 'Ultralytics',
'license': 'GPL-3.0 https://ultralytics.com/license',
'version': __version__,
@ -215,7 +215,7 @@ class Exporter:
'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
# Exports
f = [''] * len(fmts) # exported filenames
@ -259,15 +259,15 @@ class Exporter:
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
LOGGER.info(
f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
f"\nVisualize: https://netron.app")
f'\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
f'\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
f'\nVisualize: https://netron.app')
self.run_callbacks("on_export_end")
self.run_callbacks('on_export_end')
return f # return list of exported files/dirs
@try_export
@ -277,7 +277,7 @@ class Exporter:
f = self.file.with_suffix('.torchscript')
ts = torch.jit.trace(self.model, self.im, strict=False)
d = {"shape": self.im.shape, "stride": int(max(self.model.stride)), "names": self.model.names}
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
LOGGER.info(f'{prefix} optimizing for mobile...')
@ -354,7 +354,7 @@ class Exporter:
ov_model = mo.convert_model(f_onnx,
model_name=self.pretty_name,
framework="onnx",
framework='onnx',
compress_to_fp16=self.args.half) # export
ov.serialize(ov_model, f_ov) # save
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
@ -471,7 +471,7 @@ class Exporter:
if self.args.dynamic:
shape = self.im.shape
if shape[0] <= 1:
LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
profile = builder.create_optimization_profile()
for inp in inputs:
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
@ -509,8 +509,8 @@ class Exporter:
except ImportError:
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
import tensorflow as tf # noqa
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
cmds="--extra-index-url https://pypi.ngc.nvidia.com")
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(self.file).replace(self.file.suffix, '_saved_model')
@ -632,7 +632,7 @@ class Exporter:
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
tflite_model = converter.convert()
open(f, "wb").write(tflite_model)
open(f, 'wb').write(tflite_model)
return f, None
@try_export
@ -656,7 +656,7 @@ class Exporter:
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
subprocess.run(cmd.split(), check=True)
self._add_tflite_metadata(f)
return f, None
@ -707,8 +707,8 @@ class Exporter:
# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.description = "Input image to be detected."
input_meta.name = 'image'
input_meta.description = 'Input image to be detected.'
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
@ -716,8 +716,8 @@ class Exporter:
# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "output"
output_meta.description = "Coordinates of detected objects, class labels, and confidence score."
output_meta.name = 'output'
output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'
# Label file
tmp_file = Path('/tmp/meta.txt')
@ -868,8 +868,8 @@ class Exporter:
def export(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n.yaml"
cfg.format = cfg.format or "torchscript"
cfg.model = cfg.model or 'yolov8n.yaml'
cfg.format = cfg.format or 'torchscript'
# exporter = Exporter(cfg)
#
@ -888,7 +888,7 @@ def export(cfg=DEFAULT_CFG):
model.export(**vars(cfg))
if __name__ == "__main__":
if __name__ == '__main__':
"""
CLI:
yolo mode=export model=yolov8n.yaml format=onnx

View File

@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
"classify": [
'classify': [
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
'yolo.TYPE.classify.ClassificationPredictor'],
"detect": [
'detect': [
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
'yolo.TYPE.detect.DetectionPredictor'],
"segment": [
'segment': [
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
'yolo.TYPE.segment.SegmentationPredictor']}
@ -34,7 +34,7 @@ class YOLO:
A python interface which emulates a model-like behaviour by wrapping trainers.
"""
def __init__(self, model='yolov8n.pt', type="v8") -> None:
def __init__(self, model='yolov8n.pt', type='v8') -> None:
"""
Initializes the YOLO object.
@ -94,7 +94,7 @@ class YOLO:
suffix = Path(weights).suffix
if suffix == '.pt':
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args["task"]
self.task = self.model.args['task']
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
else:
@ -111,7 +111,7 @@ class YOLO:
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f"PyTorch models can be used to train, val, predict and export, i.e. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
@ -155,11 +155,11 @@ class YOLO:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
overrides = self.overrides.copy()
overrides["conf"] = 0.25
overrides['conf'] = 0.25
overrides.update(kwargs)
overrides["mode"] = kwargs.get("mode", "predict")
assert overrides["mode"] in ['track', 'predict']
overrides["save"] = kwargs.get("save", False) # not save files by default
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
self.predictor.setup_model(model=self.model)
@ -173,7 +173,7 @@ class YOLO:
from ultralytics.tracker.track import register_tracker
register_tracker(self)
# bytetrack-based method needs low confidence predictions as input
conf = kwargs.get("conf") or 0.1
conf = kwargs.get('conf') or 0.1
kwargs['conf'] = conf
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@ -188,9 +188,9 @@ class YOLO:
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
overrides = self.overrides.copy()
overrides["rect"] = True # rect batches as default
overrides['rect'] = True # rect batches as default
overrides.update(kwargs)
overrides["mode"] = "val"
overrides['mode'] = 'val'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
@ -234,18 +234,18 @@ class YOLO:
self._check_is_pytorch_model()
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get("cfg"):
if kwargs.get('cfg'):
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
overrides["task"] = self.task
overrides["mode"] = "train"
if not overrides.get("data"):
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
overrides['task'] = self.task
overrides['mode'] = 'train'
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get("resume"):
overrides["resume"] = self.ckpt_path
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.trainer = self.TrainerClass(overrides=overrides)
if not overrides.get("resume"): # manually set model only if not resuming
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
@ -267,9 +267,9 @@ class YOLO:
def _assign_ops_from_task(self):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
return model_class, trainer_class, validator_class, predictor_class
@property
@ -292,7 +292,7 @@ class YOLO:
Returns metrics if computed
"""
if not self.metrics_data:
LOGGER.info("No metrics data found! Run training or validation operation first.")
LOGGER.info('No metrics data found! Run training or validation operation first.')
return self.metrics_data

View File

@ -72,7 +72,7 @@ class BasePredictor:
"""
self.args = get_cfg(cfg, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
name = self.args.name or f'{self.args.mode}'
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
@ -97,10 +97,10 @@ class BasePredictor:
pass
def get_annotator(self, img):
raise NotImplementedError("get_annotator function needs to be implemented")
raise NotImplementedError('get_annotator function needs to be implemented')
def write_results(self, results, batch, print_string):
raise NotImplementedError("print_results function needs to be implemented")
raise NotImplementedError('print_results function needs to be implemented')
def postprocess(self, preds, img, orig_img):
return preds
@ -135,7 +135,7 @@ class BasePredictor:
def stream_inference(self, source=None, model=None):
if self.args.verbose:
LOGGER.info("")
LOGGER.info('')
# setup model
if not self.model:
@ -152,9 +152,9 @@ class BasePredictor:
self.done_warmup = True
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
self.run_callbacks("on_predict_start")
self.run_callbacks('on_predict_start')
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
self.run_callbacks('on_predict_batch_start')
self.batch = batch
path, im, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
@ -170,7 +170,7 @@ class BasePredictor:
# postprocess
with self.dt[2]:
self.results = self.postprocess(preds, im, im0s)
self.run_callbacks("on_predict_postprocess_end")
self.run_callbacks('on_predict_postprocess_end')
# visualize, save, write results
for i in range(len(im)):
@ -186,7 +186,7 @@ class BasePredictor:
if self.args.save:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
self.run_callbacks("on_predict_batch_end")
self.run_callbacks('on_predict_batch_end')
yield from self.results
# Print time (inference-only)
@ -207,7 +207,7 @@ class BasePredictor:
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
self.run_callbacks('on_predict_end')
def setup_model(self, model):
device = select_device(self.args.device)

View File

@ -36,7 +36,7 @@ class Results:
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs if probs is not None else None
self.names = names
self.comp = ["boxes", "masks", "probs"]
self.comp = ['boxes', 'masks', 'probs']
def pandas(self):
pass
@ -97,7 +97,7 @@ class Results:
return len(getattr(self, item))
def __str__(self):
str_out = ""
str_out = ''
for item in self.comp:
if getattr(self, item) is None:
continue
@ -105,7 +105,7 @@ class Results:
return str_out
def __repr__(self):
str_out = ""
str_out = ''
for item in self.comp:
if getattr(self, item) is None:
continue
@ -187,7 +187,7 @@ class Boxes:
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
assert n in {6, 7}, f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
# TODO
self.is_track = n == 7
self.boxes = boxes
@ -268,8 +268,8 @@ class Boxes:
return self.boxes.__str__()
def __repr__(self):
return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.boxes)}\n" +
f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}")
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.boxes)}\n' +
f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')
def __getitem__(self, idx):
boxes = self.boxes[idx]
@ -353,8 +353,8 @@ class Masks:
return self.masks.__str__()
def __repr__(self):
return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}\n + {self.masks.__repr__()}")
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')
def __getitem__(self, idx):
masks = self.masks[idx]
@ -374,19 +374,19 @@ class Masks:
""")
if __name__ == "__main__":
if __name__ == '__main__':
# test examples
results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
results = results.cuda()
print("--cuda--pass--")
print('--cuda--pass--')
results = results.cpu()
print("--cpu--pass--")
results = results.to("cuda:0")
print("--to-cuda--pass--")
results = results.to("cpu")
print("--to-cpu--pass--")
print('--cpu--pass--')
results = results.to('cuda:0')
print('--to-cuda--pass--')
results = results.to('cpu')
print('--to-cpu--pass--')
results = results.numpy()
print("--numpy--pass--")
print('--numpy--pass--')
# box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
# box = box.cuda()
# box = box.cpu()

View File

@ -90,7 +90,7 @@ class BaseTrainer:
# Dirs
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
name = self.args.name or f'{self.args.mode}'
if hasattr(self.args, 'save_dir'):
self.save_dir = Path(self.args.save_dir)
else:
@ -121,7 +121,7 @@ class BaseTrainer:
try:
if self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
elif self.args.data.endswith(".yaml") or self.args.task in ('detect', 'segment'):
elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
self.data = check_det_dataset(self.args.data)
if 'yaml_file' in self.data:
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
@ -175,7 +175,7 @@ class BaseTrainer:
world_size = 0
# Run subprocess if DDP training, else train normally
if world_size > 1 and "LOCAL_RANK" not in os.environ:
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
try:
subprocess.run(cmd, check=True)
@ -191,15 +191,15 @@ class BaseTrainer:
# os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
def _setup_train(self, rank, world_size):
"""
Builds dataloaders and optimizer on correct rank process.
"""
# model
self.run_callbacks("on_pretrain_routine_start")
self.run_callbacks('on_pretrain_routine_start')
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
@ -234,16 +234,16 @@ class BaseTrainer:
# dataloaders
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
if rank in {0, -1}:
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
self.validator = self.get_validator()
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model)
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks("on_pretrain_routine_end")
self.run_callbacks('on_pretrain_routine_end')
def _do_train(self, rank=-1, world_size=1):
if world_size > 1:
@ -257,24 +257,24 @@ class BaseTrainer:
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
self.run_callbacks("on_train_start")
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
self.run_callbacks('on_train_start')
self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f"Starting training for {self.epochs} epochs...")
f'Starting training for {self.epochs} epochs...')
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
self.run_callbacks("on_train_epoch_start")
self.run_callbacks('on_train_epoch_start')
self.model.train()
if rank != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
self.console.info("Closing dataloader mosaic")
self.console.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
@ -286,7 +286,7 @@ class BaseTrainer:
self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar:
self.run_callbacks("on_train_batch_start")
self.run_callbacks('on_train_batch_start')
# Warmup
ni = i + nb * epoch
if ni <= nw:
@ -302,7 +302,7 @@ class BaseTrainer:
# Forward
with torch.cuda.amp.autocast(self.amp):
batch = self.preprocess_batch(batch)
preds = self.model(batch["img"])
preds = self.model(batch['img'])
self.loss, self.loss_items = self.criterion(preds, batch)
if rank != -1:
self.loss *= world_size
@ -324,17 +324,17 @@ class BaseTrainer:
if rank in {-1, 0}:
pbar.set_description(
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
self.run_callbacks('on_batch_end')
if self.args.plots and ni in self.plot_idx:
self.plot_training_samples(batch, ni)
self.run_callbacks("on_train_batch_end")
self.run_callbacks('on_train_batch_end')
self.lr = {f"lr/pg{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.scheduler.step()
self.run_callbacks("on_train_epoch_end")
self.run_callbacks('on_train_epoch_end')
if rank in {-1, 0}:
@ -355,7 +355,7 @@ class BaseTrainer:
tnow = time.time()
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks("on_fit_epoch_end")
self.run_callbacks('on_fit_epoch_end')
# Early Stopping
if RANK != -1: # if DDP training
@ -402,7 +402,7 @@ class BaseTrainer:
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""
return data["train"], data.get("val") or data.get("test")
return data['train'], data.get('val') or data.get('test')
def setup_model(self):
"""
@ -413,9 +413,9 @@ class BaseTrainer:
model, weights = self.model, None
ckpt = None
if str(model).endswith(".pt"):
if str(model).endswith('.pt'):
weights, ckpt = attempt_load_one_weight(model)
cfg = ckpt["model"].yaml
cfg = ckpt['model'].yaml
else:
cfg = model
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
@ -441,7 +441,7 @@ class BaseTrainer:
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
"""
metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
return metrics, fitness
@ -462,38 +462,38 @@ class BaseTrainer:
raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self):
raise NotImplementedError("get_validator function not implemented in trainer")
raise NotImplementedError('get_validator function not implemented in trainer')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""
Returns dataloader derived from torch.data.Dataloader.
"""
raise NotImplementedError("get_dataloader function not implemented in trainer")
raise NotImplementedError('get_dataloader function not implemented in trainer')
def criterion(self, preds, batch):
"""
Returns loss and individual loss items as Tensor.
"""
raise NotImplementedError("criterion function not implemented in trainer")
raise NotImplementedError('criterion function not implemented in trainer')
def label_loss_items(self, loss_items=None, prefix="train"):
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
return {"loss": loss_items} if loss_items is not None else ["loss"]
return {'loss': loss_items} if loss_items is not None else ['loss']
def set_model_attributes(self):
"""
To set or update model parameters before training.
"""
self.model.names = self.data["names"]
self.model.names = self.data['names']
def build_targets(self, preds, targets):
pass
def progress_string(self):
return ""
return ''
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
@ -529,7 +529,7 @@ class BaseTrainer:
self.args = get_cfg(attempt_load_weights(last).args)
self.args.model, resume = str(last), True # reinstate
except Exception as e:
raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
"i.e. 'yolo train resume model=path/to/last.pt'") from e
self.resume = resume
@ -557,7 +557,7 @@ class BaseTrainer:
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self.console.info("Closing dataloader mosaic")
self.console.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
@ -602,5 +602,5 @@ class BaseTrainer:
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
return optimizer

View File

@ -62,7 +62,7 @@ class BaseValidator:
self.jdict = None
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
name = self.args.name or f'{self.args.mode}'
self.save_dir = save_dir or increment_path(Path(project) / name,
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@ -92,7 +92,7 @@ class BaseValidator:
else:
callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start')
assert model is not None, "Either trainer or model is needed for validation"
assert model is not None, 'Either trainer or model is needed for validation'
self.device = select_device(self.args.device, self.args.batch)
self.args.half &= self.device.type != 'cpu'
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
@ -108,7 +108,7 @@ class BaseValidator:
self.logger.info(
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
@ -142,7 +142,7 @@ class BaseValidator:
# inference
with dt[1]:
preds = model(batch["img"])
preds = model(batch['img'])
# loss
with dt[2]:
@ -166,14 +166,14 @@ class BaseValidator:
self.run_callbacks('on_val_end')
if self.training:
model.float()
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
self.speed)
if self.args.save_json and self.jdict:
with open(str(self.save_dir / "predictions.json"), 'w') as f:
self.logger.info(f"Saving {f.name}...")
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
self.logger.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
return stats
@ -183,7 +183,7 @@ class BaseValidator:
callback(self)
def get_dataloader(self, dataset_path, batch_size):
raise NotImplementedError("get_dataloader function not implemented for this validator")
raise NotImplementedError('get_dataloader function not implemented for this validator')
def preprocess(self, batch):
return batch