ultralytics 8.0.32 HUB and TensorFlow fixes (#870)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-09 01:47:34 +04:00
committed by GitHub
parent f5d003d05a
commit c9893810c7
14 changed files with 118 additions and 85 deletions

View File

@ -203,7 +203,7 @@ class Exporter:
self.im = im
self.model = model
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y)
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
self.metadata = {
'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}",
@ -213,8 +213,8 @@ class Exporter:
'stride': int(max(model.stride)),
'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
f"output shape {self.output_shape} ({file_size(file):.1f} MB)")
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
# Exports
f = [''] * len(fmts) # exported filenames
@ -234,19 +234,22 @@ class Exporter:
nms = False
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
agnostic_nms=self.args.agnostic_nms or tfjs)
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite or edgetpu:
f[7], _ = self._export_tflite(s_model,
int8=self.args.int8 or edgetpu,
data=self.args.data,
nms=nms,
agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu()
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
if tfjs:
f[9], _ = self._export_tfjs()
debug = False
if debug:
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite or edgetpu:
f[7], _ = self._export_tflite(s_model,
int8=self.args.int8 or edgetpu,
data=self.args.data,
nms=nms,
agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu()
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
if tfjs:
f[9], _ = self._export_tfjs()
if paddle: # PaddlePaddle
f[10], _ = self._export_paddle()

View File

@ -120,7 +120,7 @@ class BaseValidator:
if not pt:
self.args.rect = False
self.dataloader = self.dataloader or \
self.get_dataloader(self.data.get("val") or self.data.set("test"), self.args.batch)
self.get_dataloader(self.data.get("val") or self.data.get("test"), self.args.batch)
model.eval()
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup