update model initialization design, supports custom data/num_classes (#44)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							@ -94,7 +94,7 @@ jobs:
 | 
			
		||||
      - name: Test segmentation
 | 
			
		||||
        shell: bash  # for Windows compatibility
 | 
			
		||||
        run: |
 | 
			
		||||
          python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
 | 
			
		||||
          python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
 | 
			
		||||
      - name: Test classification
 | 
			
		||||
        shell: bash  # for Windows compatibility
 | 
			
		||||
        run: |
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -131,3 +131,4 @@ dmypy.json
 | 
			
		||||
# datasets and projects
 | 
			
		||||
datasets/
 | 
			
		||||
ultralytics-yolo/
 | 
			
		||||
runs/
 | 
			
		||||
@ -63,10 +63,8 @@ class BaseTrainer:
 | 
			
		||||
        else:
 | 
			
		||||
            self.data = check_dataset(self.data)
 | 
			
		||||
        self.trainset, self.testset = self.get_dataset(self.data)
 | 
			
		||||
        if self.args.cfg is not None:
 | 
			
		||||
            self.model = self.load_cfg(check_file(self.args.cfg))
 | 
			
		||||
        if self.args.model is not None:
 | 
			
		||||
            self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
 | 
			
		||||
        if self.args.model:
 | 
			
		||||
            self.model = self.get_model(self.args.model, self.data)
 | 
			
		||||
 | 
			
		||||
        # epoch level metrics
 | 
			
		||||
        self.metrics = {}  # handle metrics returned by validator
 | 
			
		||||
@ -261,20 +259,20 @@ class BaseTrainer:
 | 
			
		||||
        """
 | 
			
		||||
        return data["train"], data["val"]
 | 
			
		||||
 | 
			
		||||
    def get_model(self, model, pretrained):
 | 
			
		||||
    def get_model(self, model: str, data: Dict):
 | 
			
		||||
        """
 | 
			
		||||
        load/create/download model for any task
 | 
			
		||||
        """
 | 
			
		||||
        model = get_model(model)
 | 
			
		||||
        for m in model.modules():
 | 
			
		||||
            if not pretrained and hasattr(m, 'reset_parameters'):
 | 
			
		||||
                m.reset_parameters()
 | 
			
		||||
        for p in model.parameters():
 | 
			
		||||
            p.requires_grad = True
 | 
			
		||||
 | 
			
		||||
        pretrained = False
 | 
			
		||||
        if not str(model).endswith(".yaml"):
 | 
			
		||||
            pretrained = True
 | 
			
		||||
            weights = get_model(model)  # rename this to something less confusing?
 | 
			
		||||
        model = self.load_model(model_cfg=model if not pretrained else None,
 | 
			
		||||
                                weights=weights if pretrained else None,
 | 
			
		||||
                                data=self.data)
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    def load_cfg(self, cfg):
 | 
			
		||||
    def load_model(self, model_cfg, weights, data):
 | 
			
		||||
        raise NotImplementedError("This task trainer doesn't support loading cfg files")
 | 
			
		||||
 | 
			
		||||
    def get_validator(self):
 | 
			
		||||
 | 
			
		||||
@ -3,8 +3,7 @@
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Train settings -------------------------------------------------------------------------------------------------------
 | 
			
		||||
model: null  # i.e. yolov5s.pt
 | 
			
		||||
cfg: null  # i.e. yolov5s.yaml
 | 
			
		||||
model: null  # i.e. yolov5s.pt, yolo.yaml
 | 
			
		||||
data: null  # i.e. coco128.yaml
 | 
			
		||||
epochs: 300
 | 
			
		||||
batch_size: 16
 | 
			
		||||
@ -70,6 +69,7 @@ mosaic: 1.0  # image mosaic (probability)
 | 
			
		||||
mixup: 0.0  # image mixup (probability)
 | 
			
		||||
copy_paste: 0.0  # segment copy-paste (probability)
 | 
			
		||||
label_smoothing: 0.0
 | 
			
		||||
# anchors: 3
 | 
			
		||||
 | 
			
		||||
# Hydra configs --------------------------------------------------------------------------------------------------------
 | 
			
		||||
hydra:
 | 
			
		||||
 | 
			
		||||
@ -140,8 +140,3 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
 | 
			
		||||
    else:
 | 
			
		||||
        for u in [url] if isinstance(url, (str, Path)) else url:
 | 
			
		||||
            download_one(u, dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_model(model: str):
 | 
			
		||||
    # check for local weights
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,7 @@ class BaseModel(nn.Module):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def load(self, weights):
 | 
			
		||||
        # Force all tasks implement this function
 | 
			
		||||
        # Force all tasks to implement this function
 | 
			
		||||
        raise NotImplementedError("This function needs to be implemented by derived classes!")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -169,10 +169,10 @@ class DetectionModel(BaseModel):
 | 
			
		||||
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 | 
			
		||||
 | 
			
		||||
    def load(self, weights):
 | 
			
		||||
        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
 | 
			
		||||
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
 | 
			
		||||
        csd = weights['model'].float().state_dict()  # checkpoint state_dict as FP32
 | 
			
		||||
        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
 | 
			
		||||
        self.load_state_dict(csd, strict=False)  # load
 | 
			
		||||
        LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SegmentationModel(DetectionModel):
 | 
			
		||||
@ -203,11 +203,33 @@ class ClassificationModel(BaseModel):
 | 
			
		||||
        self.nc = nc
 | 
			
		||||
 | 
			
		||||
    def _from_yaml(self, cfg):
 | 
			
		||||
        # Create a YOLOv5 classification model from a *.yaml file
 | 
			
		||||
        # TODO: Create a YOLOv5 classification model from a *.yaml file
 | 
			
		||||
        self.model = None
 | 
			
		||||
 | 
			
		||||
    def load(self, weights):
 | 
			
		||||
        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
 | 
			
		||||
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
 | 
			
		||||
        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
 | 
			
		||||
        csd = model.float().state_dict()
 | 
			
		||||
        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
 | 
			
		||||
        self.load_state_dict(csd, strict=False)  # load
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def reshape_outputs(model, nc):
 | 
			
		||||
        # Update a TorchVision classification model to class count 'n' if required
 | 
			
		||||
        from ultralytics.yolo.utils.modeling.modules import Classify
 | 
			
		||||
        name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
 | 
			
		||||
        if isinstance(m, Classify):  # YOLO Classify() head
 | 
			
		||||
            if m.linear.out_features != nc:
 | 
			
		||||
                m.linear = nn.Linear(m.linear.in_features, nc)
 | 
			
		||||
        elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
 | 
			
		||||
            if m.out_features != nc:
 | 
			
		||||
                setattr(model, name, nn.Linear(m.in_features, nc))
 | 
			
		||||
        elif isinstance(m, nn.Sequential):
 | 
			
		||||
            types = [type(x) for x in m]
 | 
			
		||||
            if nn.Linear in types:
 | 
			
		||||
                i = types.index(nn.Linear)  # nn.Linear index
 | 
			
		||||
                if m[i].out_features != nc:
 | 
			
		||||
                    m[i] = nn.Linear(m[i].in_features, nc)
 | 
			
		||||
            elif nn.Conv2d in types:
 | 
			
		||||
                i = types.index(nn.Conv2d)  # nn.Conv2d index
 | 
			
		||||
                if m[i].out_channels != nc:
 | 
			
		||||
                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
 | 
			
		||||
 | 
			
		||||
@ -1,26 +1,27 @@
 | 
			
		||||
import subprocess
 | 
			
		||||
import time
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import hydra
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from ultralytics.yolo import v8
 | 
			
		||||
from ultralytics.yolo.data import build_classification_dataloader
 | 
			
		||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
 | 
			
		||||
from ultralytics.yolo.utils import colorstr
 | 
			
		||||
from ultralytics.yolo.utils.downloads import download
 | 
			
		||||
from ultralytics.yolo.utils.files import WorkingDirectory
 | 
			
		||||
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
 | 
			
		||||
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# BaseTrainer python usage
 | 
			
		||||
class ClassificationTrainer(BaseTrainer):
 | 
			
		||||
 | 
			
		||||
    def load_model(self, model_cfg, weights, data):
 | 
			
		||||
        # TODO: why treat clf models as unique. We should have clf yamls?
 | 
			
		||||
        if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
 | 
			
		||||
            model = weights
 | 
			
		||||
        else:
 | 
			
		||||
            model = ClassificationModel(model_cfg, weights, data["nc"])
 | 
			
		||||
        ClassificationModel.reshape_outputs(model, data["nc"])
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
 | 
			
		||||
        return build_classification_dataloader(path=dataset_path,
 | 
			
		||||
                                               imgsz=self.args.img_size,
 | 
			
		||||
                                               batch_size=self.args.batch_size,
 | 
			
		||||
                                               batch_size=batch_size,
 | 
			
		||||
                                               rank=rank)
 | 
			
		||||
 | 
			
		||||
    def preprocess_batch(self, batch):
 | 
			
		||||
 | 
			
		||||
@ -10,12 +10,11 @@ import torch.nn.functional as F
 | 
			
		||||
from ultralytics.yolo import v8
 | 
			
		||||
from ultralytics.yolo.data import build_dataloader
 | 
			
		||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
 | 
			
		||||
from ultralytics.yolo.utils.downloads import download
 | 
			
		||||
from ultralytics.yolo.utils.files import WorkingDirectory
 | 
			
		||||
from ultralytics.yolo.utils.anchors import check_anchors
 | 
			
		||||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
 | 
			
		||||
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
 | 
			
		||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
 | 
			
		||||
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first
 | 
			
		||||
from ultralytics.yolo.utils.torch_utils import de_parallel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# BaseTrainer python usage
 | 
			
		||||
@ -45,8 +44,15 @@ class SegmentationTrainer(BaseTrainer):
 | 
			
		||||
        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
    def load_cfg(self, cfg):
 | 
			
		||||
        return SegmentationModel(cfg, nc=80)
 | 
			
		||||
    def load_model(self, model_cfg, weights, data):
 | 
			
		||||
        model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml,
 | 
			
		||||
                                  ch=3,
 | 
			
		||||
                                  nc=data["nc"],
 | 
			
		||||
                                  anchors=self.args.get("anchors"))
 | 
			
		||||
        check_anchors(model, self.args.anchor_t, self.args.img_size)
 | 
			
		||||
        if weights:
 | 
			
		||||
            model.load(weights)
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    def get_validator(self):
 | 
			
		||||
        return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
 | 
			
		||||
@ -232,7 +238,7 @@ class SegmentationTrainer(BaseTrainer):
 | 
			
		||||
 | 
			
		||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 | 
			
		||||
def train(cfg):
 | 
			
		||||
    cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml"
 | 
			
		||||
    cfg.model = v8.ROOT / "models/yolov5n-seg.yaml"
 | 
			
		||||
    cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
 | 
			
		||||
    trainer = SegmentationTrainer(cfg)
 | 
			
		||||
    trainer.train()
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user