`ultralytics 8.0.145` Windows URL fix and Pose MPS warning (#4034)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 99cb549efe
commit a02b7e6273
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,5 +43,5 @@ ENV OMP_NUM_THREADS=1
# Run
# t=ultralytics/ultralytics:latest-jetson && sudo docker run -it --ipc=host $t
# Pull and Run with local volume mounted
# t=ultralytics/ultralytics:jetson && sudo docker pull $t && sudo docker run -it --runtime=nvidia $t
# Pull and Run with NVIDIA runtime
# t=ultralytics/ultralytics:latest-jetson && sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia $t

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.144'
__version__ = '8.0.145'
from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO

@ -244,7 +244,7 @@ def check_det_dataset(dataset, autodownload=True):
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
name = clean_url(dataset) # dataset name with URL auth stripped
m = f"\nDataset '{name}' images not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
if s and autodownload:
LOGGER.warning(m)
else:

@ -135,11 +135,12 @@ class ClassificationTrainer(BaseTrainer):
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
plot_images(
images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):

@ -74,12 +74,13 @@ class ClassificationValidator(BaseValidator):
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
plot_images(
images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""

@ -2,7 +2,7 @@
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ROOT, ops
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, ops
class PosePredictor(DetectionPredictor):
@ -10,6 +10,9 @@ class PosePredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'pose'
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""

@ -4,7 +4,7 @@ from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results
@ -18,6 +18,10 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
overrides['task'] = 'pose'
super().__init__(cfg, overrides, _callbacks)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)

@ -19,6 +19,9 @@ class PoseValidator(DetectionValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""

@ -824,7 +824,7 @@ def deprecation_warn(arg, new_arg, version=None):
def clean_url(url):
"""Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
url = Path(url).as_posix().replace(':/', '://') # Pathlib turns :// -> :/, as_posix() for Windows
return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth

Loading…
Cancel
Save