ultralytics 8.0.29
DDP-cls and default arg fixes (#813)
This commit is contained in:
@ -184,9 +184,6 @@ class Exporter:
|
||||
y = model(im) # dry runs
|
||||
if self.args.half and not coreml and not xml:
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
|
||||
f"output shape {shape} ({file_size(file):.1f} MB)")
|
||||
|
||||
# Warnings
|
||||
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
||||
@ -207,6 +204,9 @@ class Exporter:
|
||||
'stride': int(max(model.stride)),
|
||||
'names': model.names} # model metadata
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
|
||||
f"output shape {self.output_shape} ({file_size(file):.1f} MB)")
|
||||
|
||||
# Exports
|
||||
f = [''] * len(fmts) # exported filenames
|
||||
if jit: # TorchScript
|
||||
@ -220,9 +220,8 @@ class Exporter:
|
||||
if coreml: # CoreML
|
||||
f[4], _ = self._export_coreml()
|
||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||
raise NotImplementedError('YOLOv8 TensorFlow export support is still under development. '
|
||||
'Please consider contributing to the effort if you have TF expertise. Thank you!')
|
||||
assert not isinstance(model, ClassificationModel), 'ClassificationModel TF exports not yet supported.'
|
||||
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export support is still under development. '
|
||||
'Please consider contributing to the effort if you have TF expertise. Thank you!')
|
||||
nms = False
|
||||
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
|
||||
agnostic_nms=self.args.agnostic_nms or tfjs)
|
||||
@ -236,7 +235,7 @@ class Exporter:
|
||||
agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu()
|
||||
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(s_model.outputs))
|
||||
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
|
||||
if tfjs:
|
||||
f[9], _ = self._export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
@ -552,13 +551,13 @@ class Exporter:
|
||||
return f, keras_model
|
||||
|
||||
@try_export
|
||||
def _export_pb(self, keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
def _export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
# YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
|
||||
import tensorflow as tf # noqa
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = file.with_suffix('.pb')
|
||||
f = self.file.with_suffix('.pb')
|
||||
|
||||
m = tf.function(lambda x: keras_model(x)) # full model
|
||||
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
|
||||
|
@ -119,7 +119,6 @@ class YOLO:
|
||||
def fuse(self):
|
||||
self.model.fuse()
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
@ -258,8 +257,6 @@ class YOLO:
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \
|
||||
'half', 'v5loader':
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
|
||||
args.pop(arg, None)
|
||||
|
||||
args["device"] = '' # set device to '' to prevent auto-DDP usage
|
||||
|
@ -457,7 +457,7 @@ class BaseTrainer:
|
||||
def get_validator(self):
|
||||
raise NotImplementedError("get_validator function not implemented in trainer")
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
"""
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user