ultralytics 8.0.94
HUBDatasetStats() Segment and Pose support (#2450)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: JF Chen <k-2feng@hotmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
@ -72,9 +72,8 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def build_dataset(self, img_path, mode='train'):
|
||||
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
|
||||
return dataset
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
return ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
|
@ -46,7 +46,8 @@ class ClassificationValidator(BaseValidator):
|
||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
|
@ -32,7 +32,7 @@ class DetectionTrainer(BaseTrainer):
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'):
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""TODO: manage splits differently."""
|
||||
# Calculate stride - check if model is initialized
|
||||
if self.args.v5loader:
|
||||
@ -62,8 +62,7 @@ class DetectionTrainer(BaseTrainer):
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
workers = self.args.workers if mode == 'train' else self.args.workers * 2
|
||||
dataloader = build_dataloader(dataset, batch_size, workers, shuffle, rank)
|
||||
return dataloader
|
||||
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||
|
@ -144,7 +144,8 @@ class DetectionValidator(BaseValidator):
|
||||
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
|
||||
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
|
||||
|
||||
def _process_batch(self, detections, labels):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user