@@ -5,6 +5,7 @@ Train a model on a dataset
 Usage:
     $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
 """
+import math
 import os
 import subprocess
 import time
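For readers coming from the Python side, the CLI command in the usage docstring has an equivalent Python API call; a minimal sketch, assuming the ultralytics package's YOLO class (which is not part of this diff):

from ultralytics import YOLO

# Mirrors: yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
model = YOLO('yolov8n.pt')  # load a pretrained detection model
model.train(data='coco128.yaml', imgsz=640, epochs=100, batch=16)  # train with the same settings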
@@ -14,11 +15,10 @@ from pathlib import Path
 import numpy as np
 import torch
-import torch.distributed as dist
+from torch import distributed as dist
-import torch.nn as nn
+from torch import nn, optim
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import lr_scheduler
 from tqdm import tqdm
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
@@ -234,33 +234,35 @@ class BaseTrainer:
                 SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                             'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
 
+        # Dataloaders
+        batch_size = self.batch_size // max(world_size, 1)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
+        if RANK in (-1, 0):
+            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
+            self.validator = self.get_validator()
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
+            self.ema = ModelEMA(self.model)
+            if self.args.plots and not self.args.v5loader:
+                self.plot_training_labels()
+
         # Optimizer
         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
         weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
         self.optimizer = self.build_optimizer(model=self.model,
                                               name=self.args.optimizer,
                                               lr=self.args.lr0,
                                               momentum=self.args.momentum,
-                                              decay=weight_decay)
+                                              decay=weight_decay,
+                                              iterations=iterations)
         # Scheduler
         if self.args.cos_lr:
             self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
         else:
             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
-        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
         self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
-
-        # Dataloaders
-        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
-        if RANK in (-1, 0):
-            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
-            self.validator = self.get_validator()
-            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
-            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
-            self.ema = ModelEMA(self.model)
-            if self.args.plots and not self.args.v5loader:
-                self.plot_training_labels()
         self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
         self.run_callbacks('on_pretrain_routine_end')
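The optimizer and scheduler wiring in the hunk above reduces to two ideas: gradient accumulation scaled to a nominal batch size (nbs), and a LambdaLR schedule driven by either a linear or a cosine lambda. Below is a self-contained sketch of that logic; the concrete values (nbs=64, batch=16, lr0=0.01, lrf=0.01) and the local one_cycle helper are illustrative assumptions rather than values read from this diff:

import math

from torch import nn, optim

# Gradient accumulation scaled to a nominal batch size, as in the hunk above
nbs, batch_size, base_wd = 64, 16, 0.0005                # illustrative values
accumulate = max(round(nbs / batch_size), 1)             # 4: step the optimizer every 4 batches
weight_decay = base_wd * batch_size * accumulate / nbs   # 0.0005: unchanged when batch * accumulate == nbs

# Linear vs. cosine learning-rate lambdas feeding LambdaLR
epochs, lr0, lrf, cos_lr = 100, 0.01, 0.01, False

def one_cycle(y1=1.0, y2=0.01, steps=100):
    # local stand-in for the Ultralytics one_cycle utility: cosine ramp from y1 down to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

lf = one_cycle(1, lrf, epochs) if cos_lr else (lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf)

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=lr0, momentum=0.9, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

for epoch in range(epochs):
    optimizer.step()   # stand-in for a full training epoch
    scheduler.step()   # next epoch's lr = lr0 * lf(epoch + 1)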
@@ -603,24 +605,30 @@ class BaseTrainer:
             if hasattr(self.train_loader.dataset, 'close_mosaic'):
                 self.train_loader.dataset.close_mosaic(hyp=self.args)
 
-    @staticmethod
-    def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
-        Builds an optimizer with the specified parameters and parameter groups.
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
+        momentum, weight decay, and number of iterations.
 
         Args:
-            model (nn.Module): model to optimize
-            name (str): name of the optimizer to use
-            lr (float): learning rate
-            momentum (float): momentum
-            decay (float): weight decay
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.
 
         Returns:
-            optimizer (torch.optim.Optimizer): the built optimizer
+            (torch.optim.Optimizer): The constructed optimizer.
         """
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == 'auto':
+            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 6000 else ('NAdam', 0.001, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for NAdam
         for module_name, module in model.named_modules():
             for param_name, param in module.named_parameters(recurse=False):
@@ -632,19 +640,21 @@ class BaseTrainer:
                 else:  # weight (with decay)
                     g[0].append(param)
 
-        if name == 'Adam':
-            optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
-        elif name == 'AdamW':
-            optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == 'RMSProp':
-            optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
         elif name == 'SGD':
-            optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
         else:
-            raise NotImplementedError(f'Optimizer {name} not implemented.')
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers "
+                f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. '
+                'To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')
 
         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-        LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
-                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
         return optimizer
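To show the rewritten build_optimizer logic in isolation, here is a minimal free-function sketch of the same idea: biases and normalization-layer weights are exempt from weight decay, only plain weights receive it, and name='auto' picks SGD or NAdam from the iteration budget (the 6000 cut-off comes from the diff above). This simplification drops RMSProp, logging and the warmup_bias_lr side effect, so it is an illustration rather than the Ultralytics implementation:

from torch import nn, optim


def build_optimizer_sketch(model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    g = [], [], []  # parameter groups: (decayed weights, norm weights, biases)
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # BatchNorm2d, LayerNorm, ...
    if name == 'auto':
        # small iteration budgets favour an adaptive optimizer, large ones favour SGD
        name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 6000 else ('NAdam', 0.001, 0.9)
    for module in model.modules():
        for param_name, param in module.named_parameters(recurse=False):
            if param_name == 'bias':
                g[2].append(param)   # biases: no decay
            elif isinstance(module, bn):
                g[1].append(param)   # norm-layer weights: no decay
            else:
                g[0].append(param)   # plain weights: decayed
    if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
        optimizer = getattr(optim, name)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == 'SGD':
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented in this sketch.')
    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
    return optimizer


# Usage: with iterations <= 6000, 'auto' resolves to NAdam at lr=0.001
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 4, 1))
opt = build_optimizer_sketch(net, name='auto', iterations=4000)
print(type(opt).__name__, [len(pg['params']) for pg in opt.param_groups])  # NAdam [3, 2, 1]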