General cleanup (#69)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -2,18 +2,37 @@ import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_dataloader
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||
from ultralytics.yolo.utils.modeling.tasks import DetectionModel
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
from ..segment import SegmentationTrainer
|
||||
from .val import DetectionValidator
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class DetectionTrainer(SegmentationTrainer):
|
||||
class DetectionTrainer(BaseTrainer):
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||
return batch
|
||||
|
||||
def set_model_attributes(self):
|
||||
nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
|
||||
self.args.box *= 3 / nl # scale to layers
|
||||
self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
self.model.nc = self.data["nc"] # attach number of classes to model
|
||||
self.model.args = self.args # attach hyperparameters to model
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None):
|
||||
model = DetectionModel(model_cfg or weights["model"].yaml,
|
||||
@ -27,7 +46,10 @@ class DetectionTrainer(SegmentationTrainer):
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
return DetectionValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, args=self.args)
|
||||
return v8.detect.DetectionValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
logger=self.console,
|
||||
args=self.args)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
head = de_parallel(self.model).model[-1]
|
||||
|
||||
Reference in New Issue
Block a user