ultralytics 8.0.29
DDP-cls and default arg fixes (#813)
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
@ -9,7 +8,7 @@ from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||
from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
||||
from ultralytics.yolo.utils.torch_utils import strip_optimizer, is_parallel
|
||||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
@ -56,7 +55,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith(".pt"):
|
||||
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
||||
for p in model.parameters():
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.endswith(".yaml"):
|
||||
self.model = self.get_model(cfg=model)
|
||||
@ -75,8 +74,12 @@ class ClassificationTrainer(BaseTrainer):
|
||||
augment=mode == "train",
|
||||
rank=rank,
|
||||
workers=self.args.workers)
|
||||
# Attach inference transforms
|
||||
if mode != "train":
|
||||
self.model.transforms = loader.dataset.torch_transforms # attach inference transforms
|
||||
if is_parallel(self.model):
|
||||
self.model.module.transforms = loader.dataset.torch_transforms
|
||||
else:
|
||||
self.model.transforms = loader.dataset.torch_transforms
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
|
Reference in New Issue
Block a user