ultralytics 8.0.48
Edge TPU fix and Metrics updates (#1171)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
@ -243,15 +243,12 @@ class Exporter:
|
||||
if coreml: # CoreML
|
||||
f[4], _ = self._export_coreml()
|
||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export is still under development. '
|
||||
'Please consider contributing to the effort if you have TF expertise. Thank you!')
|
||||
nms = False
|
||||
self.args.int8 |= edgetpu
|
||||
f[5], s_model = self._export_saved_model()
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self._export_pb(s_model)
|
||||
if tflite:
|
||||
f[7], _ = self._export_tflite(s_model, nms=nms, agnostic_nms=self.args.agnostic_nms)
|
||||
f[7], _ = self._export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu(tflite_model=str(
|
||||
Path(f[5]) / (self.file.stem + '_full_integer_quant.tflite'))) # int8 in/out
|
||||
@ -619,20 +616,18 @@ class Exporter:
|
||||
@try_export
|
||||
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
||||
# YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
|
||||
LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
|
||||
|
||||
cmd = 'edgetpu_compiler --version'
|
||||
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
|
||||
assert LINUX, f'export only supported on Linux. See {help_url}'
|
||||
if subprocess.run(f'{cmd} > /dev/null', shell=True).returncode != 0:
|
||||
if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
|
||||
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
|
||||
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
|
||||
for c in (
|
||||
# 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', # errors
|
||||
'wget --no-check-certificate -q -O - https://packages.cloud.google.com/apt/doc/apt-key.gpg | '
|
||||
'sudo apt-key add -',
|
||||
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' # no comma
|
||||
'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
|
||||
'sudo apt-get update',
|
||||
'sudo apt-get install edgetpu-compiler'):
|
||||
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
|
||||
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
|
||||
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
|
||||
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
|
||||
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
|
||||
|
||||
|
@ -43,7 +43,7 @@ class YOLO:
|
||||
cfg (str): The model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): The checkpoint file path.
|
||||
overrides (dict): Overrides for the trainer object.
|
||||
metrics_data (Any): The data for metrics.
|
||||
metrics (Any): The data for metrics.
|
||||
|
||||
Methods:
|
||||
__call__(source=None, stream=False, **kwargs):
|
||||
@ -67,7 +67,7 @@ class YOLO:
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', task=None) -> None:
|
||||
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
|
||||
"""
|
||||
Initializes the YOLO model.
|
||||
|
||||
@ -83,7 +83,8 @@ class YOLO:
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.metrics_data = None
|
||||
self.metrics = None # validation/training metrics
|
||||
self.session = session # HUB session
|
||||
|
||||
# Load or create new YOLO model
|
||||
suffix = Path(model).suffix
|
||||
@ -184,6 +185,7 @@ class YOLO:
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
@ -217,7 +219,6 @@ class YOLO:
|
||||
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
@smart_inference_mode()
|
||||
def track(self, source=None, stream=False, **kwargs):
|
||||
from ultralytics.tracker import register_tracker
|
||||
register_tracker(self)
|
||||
@ -252,7 +253,7 @@ class YOLO:
|
||||
|
||||
validator = TASK_MAP[self.task][2](args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics_data = validator.metrics
|
||||
self.metrics = validator.metrics
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@ -314,12 +315,13 @@ class YOLO:
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# update model and cfg after training
|
||||
if RANK in {0, -1}:
|
||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||
self.overrides = self.model.args
|
||||
self.metrics_data = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
||||
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
@ -352,15 +354,6 @@ class YOLO:
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""
|
||||
Returns metrics if computed
|
||||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info('No metrics data found! Run training or validation operation first.')
|
||||
return self.metrics_data
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
"""
|
||||
|
@ -139,7 +139,8 @@ class Results:
|
||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
|
||||
|
||||
if logits is not None:
|
||||
top5i = logits.argsort(0, descending=True)[:5].tolist() # top 5 indices
|
||||
n5 = min(len(self.names), 5)
|
||||
top5i = logits.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
|
Reference in New Issue
Block a user