ultralytics 8.0.155
allow imgsz
and batch
resume changes (#4366)
Co-authored-by: Mostafa Nemati <58460889+monemati@users.noreply.github.com> Co-authored-by: Eduard Voiculescu <eduardvoiculescu95@gmail.com>
This commit is contained in:
@ -249,11 +249,11 @@ class Exporter:
|
||||
f[4], _ = self.export_coreml()
|
||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||
self.args.int8 |= edgetpu
|
||||
f[5], s_model = self.export_saved_model()
|
||||
f[5], keras_model = self.export_saved_model()
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self.export_pb(s_model)
|
||||
f[6], _ = self.export_pb(keras_model=keras_model)
|
||||
if tflite:
|
||||
f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
||||
f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
|
||||
if tfjs:
|
||||
@ -671,10 +671,7 @@ class Exporter:
|
||||
for file in f.rglob('*.tflite'):
|
||||
f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
|
||||
|
||||
# Load saved_model
|
||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||
|
||||
return str(f), keras_model
|
||||
return str(f), tf.saved_model.load(f, tags=None, options=None) # load saved_model as Keras model
|
||||
|
||||
@try_export
|
||||
def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
|
@ -81,7 +81,7 @@ class BaseTrainer:
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.check_resume()
|
||||
self.check_resume(overrides)
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.validator = None
|
||||
self.model = None
|
||||
@ -576,7 +576,7 @@ class BaseTrainer:
|
||||
self.metrics.pop('fitness', None)
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
|
||||
def check_resume(self):
|
||||
def check_resume(self, overrides):
|
||||
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
@ -589,8 +589,13 @@ class BaseTrainer:
|
||||
if not Path(ckpt_args['data']).exists():
|
||||
ckpt_args['data'] = self.args.data
|
||||
|
||||
resume = True
|
||||
self.args = get_cfg(ckpt_args)
|
||||
self.args.model, resume = str(last), True # reinstate
|
||||
self.args.model = str(last) # reinstate model
|
||||
for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
|
||||
if k in overrides:
|
||||
setattr(self.args, k, overrides[k])
|
||||
|
||||
except Exception as e:
|
||||
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
|
||||
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
||||
|
Reference in New Issue
Block a user