Simplify argument names (#141)
This commit is contained in:
@ -164,14 +164,14 @@ class Exporter:
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
|
||||
# Checks
|
||||
# if self.args.batch_size == model.args['batch_size']: # user has not modified training batch_size
|
||||
self.args.batch_size = 1
|
||||
# if self.args.batch == model.args['batch_size']: # user has not modified training batch_size
|
||||
self.args.batch = 1
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
if self.args.optimize:
|
||||
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
||||
|
||||
# Input
|
||||
im = torch.zeros(self.args.batch_size, 3, *self.imgsz).to(self.device)
|
||||
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
|
||||
file = Path(getattr(model, 'pt_path', None) or model.yaml['yaml_file'])
|
||||
if file.suffix == '.yaml':
|
||||
file = Path(file.name)
|
||||
|
@ -102,7 +102,7 @@ class BaseTrainer:
|
||||
yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
|
||||
self.batch_size = self.args.batch_size
|
||||
self.batch_size = self.args.batch
|
||||
self.epochs = self.args.epochs
|
||||
self.start_epoch = 0
|
||||
if RANK == -1:
|
||||
|
@ -87,18 +87,18 @@ class BaseValidator:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
self.run_callbacks('on_val_start')
|
||||
assert model is not None, "Either trainer or model is needed for validation"
|
||||
self.device = select_device(self.args.device, self.args.batch_size)
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
|
||||
self.model = model
|
||||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
||||
if engine:
|
||||
self.args.batch_size = model.batch_size
|
||||
self.args.batch = model.batch_size
|
||||
else:
|
||||
self.device = model.device
|
||||
if not pt and not jit:
|
||||
self.args.batch_size = 1 # export.py models default to batch-size 1
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
self.logger.info(
|
||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
@ -110,7 +110,7 @@ class BaseValidator:
|
||||
if self.device.type == 'cpu':
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
self.dataloader = self.dataloader or \
|
||||
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
||||
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch)
|
||||
self.data = data
|
||||
|
||||
model.eval()
|
||||
|
Reference in New Issue
Block a user