ultralytics 8.0.145
Windows URL fix and Pose MPS warning (#4034)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -135,11 +135,12 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples with their annotations."""
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
plot_images(
|
||||
images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
|
@ -74,12 +74,13 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plot validation image samples."""
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
plot_images(
|
||||
images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
|
Reference in New Issue
Block a user