Replace nosave and noval with save and val (#127)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
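The rename flips the polarity of two boolean options: what was previously requested with `nosave=True` / `noval=True` is now expressed as `save=False` / `val=False`, with saving and validation enabled by default. A minimal sketch of that mapping; the `migrate_args` helper and the plain namespace are illustrative only, not part of this commit:

    from types import SimpleNamespace

    def migrate_args(old: SimpleNamespace) -> SimpleNamespace:
        # Hypothetical helper: translate the old negative flags to the new positive ones.
        new = SimpleNamespace(**vars(old))
        if hasattr(new, 'nosave'):
            new.save = not new.nosave  # nosave=True -> save=False
            del new.nosave
        if hasattr(new, 'noval'):
            new.val = not new.noval    # noval=True -> val=False
            del new.noval
        return new

    old_args = SimpleNamespace(nosave=False, noval=True)
    print(migrate_args(old_args))  # namespace(save=True, val=False)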
@@ -104,7 +104,6 @@ class BasePredictor:
     def setup(self, source=None, model=None):
         # source
         source = str(source or self.args.source)
-        self.save_img = not self.args.nosave and not source.endswith('.txt')
         is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
         is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
         webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
@@ -168,10 +167,10 @@ class BasePredictor:
             p = Path(path)
             s += self.write_results(i, preds, (p, im, im0s))

-            if self.args.view_img:
+            if self.args.show:
                 self.show(p)

-            if self.save_img:
+            if self.args.save:
                 self.save_preds(vid_cap, i, str(self.save_dir / p.name))

         # Print time (inference-only)
@@ -182,7 +181,7 @@ class BasePredictor:
         LOGGER.info(
             f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}'
             % t)
-        if self.args.save_txt or self.save_img:
+        if self.args.save_txt or self.args.save:
             s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
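On the predictor side, the derived `self.save_img` attribute is dropped and display/disk output is gated directly by the user-facing flags. A rough, self-contained sketch of that control flow; `handle_frame`, `show_frame`, and `save_frame` are stand-ins, not the project's API:

    from types import SimpleNamespace

    def show_frame(frame):            # stand-in for self.show(p)
        print(f"showing {frame}")

    def save_frame(frame, path):      # stand-in for self.save_preds(...)
        print(f"saving {frame} -> {path}")

    def handle_frame(args, frame, path):
        if args.show:                 # was: if self.args.view_img
            show_frame(frame)
        if args.save:                 # was: if self.save_img (derived from not args.nosave)
            save_frame(frame, path)

    handle_frame(SimpleNamespace(show=False, save=True), 'frame0', 'runs/predict/frame0.jpg')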
@@ -244,12 +244,12 @@ class BaseTrainer:
             for i, batch in pbar:
                 self.trigger_callbacks("on_train_batch_start")

-                # update dataloader attributes (optional)
+                # Update dataloader attributes (optional)
                 if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
                     LOGGER.info("Closing dataloader mosaic")
                     self.train_loader.dataset.mosaic = False

-                # warmup
+                # Warmup
                 ni = i + nb * epoch
                 if ni <= nw:
                     xi = [0, nw]  # x interp
@@ -261,7 +261,7 @@ class BaseTrainer:
                         if 'momentum' in x:
                             x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

-                # forward
+                # Forward
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     preds = self.model(batch["img"])
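The warmup block (only its comment changes here) linearly ramps optimizer hyperparameters with `np.interp` over the first `nw` iterations. A standalone sketch of the same scheme, with example warmup values standing in for the trainer's real arguments:

    import numpy as np

    nw = 100                      # number of warmup iterations (assumed for illustration)
    warmup_momentum, momentum = 0.8, 0.937
    warmup_bias_lr, lr0 = 0.1, 0.01

    for ni in (0, 50, 100):       # iteration counter, as in ni = i + nb * epoch
        xi = [0, nw]              # x interp
        lr = np.interp(ni, xi, [warmup_bias_lr, lr0])       # bias lr falls towards lr0
        m = np.interp(ni, xi, [warmup_momentum, momentum])   # momentum rises towards its target
        print(f"ni={ni:3d}  lr={lr:.4f}  momentum={m:.4f}")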
@@ -271,15 +271,15 @@ class BaseTrainer:
                     self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                         else self.loss_items

-                # backward
+                # Backward
                 self.scaler.scale(self.loss).backward()

-                # optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni

-                # log
+                # Log
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
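The forward/backward/optimize section follows the mixed-precision recipe from the linked PyTorch AMP notes: run the forward pass under autocast, scale the loss before backward, and step the optimizer only every `accumulate` iterations. A compact, self-contained sketch of that pattern on a toy model; the model, data, and `accumulate` value are invented for illustration rather than taken from the trainer:

    import torch
    import torch.nn as nn

    amp = torch.cuda.is_available()                 # enable mixed precision only on GPU
    device = 'cuda' if amp else 'cpu'
    model = nn.Linear(8, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    accumulate = 4                                   # step the optimizer every 4 batches
    last_opt_step = -1

    for ni in range(8):                              # ni plays the role of i + nb * epoch
        x = torch.randn(16, 8, device=device)
        y = torch.randn(16, 1, device=device)

        # Forward (under autocast, as in the trainer)
        with torch.cuda.amp.autocast(enabled=amp):
            loss = nn.functional.mse_loss(model(x), y)

        # Backward
        scaler.scale(loss).backward()

        # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
        if ni - last_opt_step >= accumulate:
            scaler.step(optimizer)                   # unscales gradients, then optimizer.step()
            scaler.update()
            optimizer.zero_grad()
            last_opt_step = ni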
@@ -298,17 +298,17 @@ class BaseTrainer:
             self.trigger_callbacks("on_train_epoch_end")

             if rank in {-1, 0}:
-                # validation
+                # Validation
                 self.trigger_callbacks('on_val_start')
                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                 final_epoch = (epoch + 1 == self.epochs)
-                if not self.args.noval or final_epoch:
+                if self.args.val or final_epoch:
                     self.metrics, self.fitness = self.validate()
                 self.trigger_callbacks('on_val_end')
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})

-                # save model
-                if (not self.args.nosave) or (epoch + 1 == self.epochs):
+                # Save model
+                if self.args.save or (epoch + 1 == self.epochs):
                     self.save_model()
                     self.trigger_callbacks('on_model_save')
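At the end of each epoch, validation and checkpointing are now opt-out: `val=False` or `save=False` skips them on intermediate epochs, while the final epoch always runs both. A small sketch of that gating, with a plain namespace and print statements standing in for the trainer's argument object and its `validate()` / `save_model()` calls:

    from types import SimpleNamespace

    def end_of_epoch(args, epoch, epochs):
        # Sketch of the epoch-end gating; the real trainer calls validate() and save_model().
        final_epoch = (epoch + 1 == epochs)
        if args.val or final_epoch:                  # was: if not self.args.noval or final_epoch
            print(f"epoch {epoch}: validate")
        if args.save or final_epoch:                 # was: if (not self.args.nosave) or final_epoch
            print(f"epoch {epoch}: save checkpoint")

    args = SimpleNamespace(val=False, save=False)    # user opted out of per-epoch val/save
    for epoch in range(3):
        end_of_epoch(args, epoch, epochs=3)          # only epoch 2 (the last) validates and saves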
@@ -319,7 +319,7 @@ class BaseTrainer:
         # TODO: termination condition

         if rank in {-1, 0}:
-            # do the last evaluation with best.pt
+            # Do final val with best.pt
             self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
                      f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
             self.final_eval()