Add Adamax, NAdam, RAdam optimizers (#2969)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 1 year ago committed by GitHub
parent f502b50365
commit 451cf8b647

@@ -55,50 +55,50 @@ include the choice of optimizer, the choice of loss function, and the size and c
is important to carefully tune and experiment with these settings to achieve the best possible performance for a given
task.

| Key               | Value    | Description                                                                        |
|-------------------|----------|------------------------------------------------------------------------------------|
| `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml |
| `data`            | `None`   | path to data file, i.e. coco128.yaml |
| `epochs`          | `100`    | number of epochs to train for |
| `patience`        | `50`     | epochs to wait for no observable improvement for early stopping of training |
| `batch`           | `16`     | number of images per batch (-1 for AutoBatch) |
| `imgsz`           | `640`    | size of input images as integer or w,h |
| `save`            | `True`   | save train checkpoints and predict results |
| `save_period`     | `-1`     | Save checkpoint every x epochs (disabled if < 1) |
| `cache`           | `False`  | True/ram, disk or False. Use cache for data loading |
| `device`          | `None`   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu |
| `workers`         | `8`      | number of worker threads for data loading (per RANK if DDP) |
| `project`         | `None`   | project name |
| `name`            | `None`   | experiment name |
| `exist_ok`        | `False`  | whether to overwrite existing experiment |
| `pretrained`      | `False`  | whether to use a pretrained model |
-| `optimizer`       | `'SGD'`  | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] |
+| `optimizer`       | `'auto'` | optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] |
| `verbose`         | `False`  | whether to print verbose output |
| `seed`            | `0`      | random seed for reproducibility |
| `deterministic`   | `True`   | whether to enable deterministic mode |
| `single_cls`      | `False`  | train multi-class data as single-class |
| `rect`            | `False`  | rectangular training with each batch collated for minimum padding |
| `cos_lr`          | `False`  | use cosine learning rate scheduler |
| `close_mosaic`    | `0`      | (int) disable mosaic augmentation for final epochs |
| `resume`          | `False`  | resume training from last checkpoint |
| `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False] |
| `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set) |
| `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers |
| `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3) |
| `lrf`             | `0.01`   | final learning rate (lr0 * lrf) |
| `momentum`        | `0.937`  | SGD momentum/Adam beta1 |
| `weight_decay`    | `0.0005` | optimizer weight decay 5e-4 |
| `warmup_epochs`   | `3.0`    | warmup epochs (fractions ok) |
| `warmup_momentum` | `0.8`    | warmup initial momentum |
| `warmup_bias_lr`  | `0.1`    | warmup initial bias lr |
| `box`             | `7.5`    | box loss gain |
| `cls`             | `0.5`    | cls loss gain (scale with pixels) |
| `dfl`             | `1.5`    | dfl loss gain |
| `pose`            | `12.0`   | pose loss gain (pose-only) |
| `kobj`            | `2.0`    | keypoint obj loss gain (pose-only) |
| `label_smoothing` | `0.0`    | label smoothing (fraction) |
| `nbs`             | `64`     | nominal batch size |
| `overlap_mask`    | `True`   | masks should overlap during training (segment train only) |
| `mask_ratio`      | `4`      | mask downsample ratio (segment train only) |
| `dropout`         | `0.0`    | use dropout regularization (classify train only) |
| `val`             | `True`   | validate/test during training |
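
For reference, a minimal sketch of how the expanded `optimizer` choices can be passed through the Python API. The `yolov8n.pt` / `coco128.yaml` values are just the usual doc placeholders, not part of this diff:

```python
from ultralytics import YOLO

# Load a pretrained detection model (standard doc placeholder weights)
model = YOLO('yolov8n.pt')

# Any of the expanded choices can be requested directly; leaving optimizer='auto'
# (the new default) lets the trainer pick SGD or NAdam from the estimated iteration count
model.train(data='coco128.yaml', epochs=100, imgsz=640, optimizer='NAdam', lr0=0.001)
```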

@@ -77,53 +77,53 @@ include:
The training settings for YOLO models encompass various hyperparameters and configurations used during the training process. These settings influence the model's performance, speed, and accuracy. Key training settings include batch size, learning rate, momentum, and weight decay. Additionally, the choice of optimizer, loss function, and training dataset composition can impact the training process. Careful tuning and experimentation with these settings are crucial for optimizing performance.

| Key               | Value    | Description                                                                        |
|-------------------|----------|------------------------------------------------------------------------------------|
| `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml |
| `data`            | `None`   | path to data file, i.e. coco128.yaml |
| `epochs`          | `100`    | number of epochs to train for |
| `patience`        | `50`     | epochs to wait for no observable improvement for early stopping of training |
| `batch`           | `16`     | number of images per batch (-1 for AutoBatch) |
| `imgsz`           | `640`    | size of input images as integer or w,h |
| `save`            | `True`   | save train checkpoints and predict results |
| `save_period`     | `-1`     | Save checkpoint every x epochs (disabled if < 1) |
| `cache`           | `False`  | True/ram, disk or False. Use cache for data loading |
| `device`          | `None`   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu |
| `workers`         | `8`      | number of worker threads for data loading (per RANK if DDP) |
| `project`         | `None`   | project name |
| `name`            | `None`   | experiment name |
| `exist_ok`        | `False`  | whether to overwrite existing experiment |
| `pretrained`      | `False`  | whether to use a pretrained model |
-| `optimizer`       | `'SGD'`  | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] |
+| `optimizer`       | `'auto'` | optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] |
| `verbose`         | `False`  | whether to print verbose output |
| `seed`            | `0`      | random seed for reproducibility |
| `deterministic`   | `True`   | whether to enable deterministic mode |
| `single_cls`      | `False`  | train multi-class data as single-class |
| `rect`            | `False`  | rectangular training with each batch collated for minimum padding |
| `cos_lr`          | `False`  | use cosine learning rate scheduler |
| `close_mosaic`    | `0`      | (int) disable mosaic augmentation for final epochs |
| `resume`          | `False`  | resume training from last checkpoint |
| `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False] |
| `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set) |
| `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers |
| `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3) |
| `lrf`             | `0.01`   | final learning rate (lr0 * lrf) |
| `momentum`        | `0.937`  | SGD momentum/Adam beta1 |
| `weight_decay`    | `0.0005` | optimizer weight decay 5e-4 |
| `warmup_epochs`   | `3.0`    | warmup epochs (fractions ok) |
| `warmup_momentum` | `0.8`    | warmup initial momentum |
| `warmup_bias_lr`  | `0.1`    | warmup initial bias lr |
| `box`             | `7.5`    | box loss gain |
| `cls`             | `0.5`    | cls loss gain (scale with pixels) |
| `dfl`             | `1.5`    | dfl loss gain |
| `pose`            | `12.0`   | pose loss gain (pose-only) |
| `kobj`            | `2.0`    | keypoint obj loss gain (pose-only) |
| `label_smoothing` | `0.0`    | label smoothing (fraction) |
| `nbs`             | `64`     | nominal batch size |
| `overlap_mask`    | `True`   | masks should overlap during training (segment train only) |
| `mask_ratio`      | `4`      | mask downsample ratio (segment train only) |
| `dropout`         | `0.0`    | use dropout regularization (classify train only) |
| `val`             | `True`   | validate/test during training |

[Train Guide](../modes/train.md){ .md-button .md-button--primary}

@@ -20,7 +20,7 @@ project: # project name
name: # experiment name, results saved to 'project/name' directory
exist_ok: False # whether to overwrite existing experiment
pretrained: False # whether to use a pretrained model
-optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+optimizer: auto # optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
verbose: True # whether to print verbose output
seed: 0 # random seed for reproducibility
deterministic: True # whether to enable deterministic mode
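
If the previous fixed-SGD default is still wanted, it can simply be requested explicitly now that the default has moved to `auto`. A small sketch, assuming the usual model/data placeholders:

```python
from ultralytics import YOLO

# Explicitly request the old default instead of letting 'auto' choose
# (CLI equivalent: yolo mode=train model=yolov8n.pt data=coco128.yaml optimizer=SGD)
YOLO('yolov8n.pt').train(data='coco128.yaml', epochs=100, optimizer='SGD')
```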

@@ -5,6 +5,7 @@ Train a model on a dataset
Usage:
    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
"""
+import math
import os
import subprocess
import time
@@ -14,11 +15,10 @@ from pathlib import Path
import numpy as np
import torch
-import torch.distributed as dist
-import torch.nn as nn
+from torch import distributed as dist
+from torch import nn, optim
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import lr_scheduler
from tqdm import tqdm

from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
@@ -234,33 +234,35 @@ class BaseTrainer:
                SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                            'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')

+        # Dataloaders
+        batch_size = self.batch_size // max(world_size, 1)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
+        if RANK in (-1, 0):
+            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
+            self.validator = self.get_validator()
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
+            self.ema = ModelEMA(self.model)
+            if self.args.plots and not self.args.v5loader:
+                self.plot_training_labels()

        # Optimizer
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
-                                              decay=weight_decay)
+                                              decay=weight_decay,
+                                              iterations=iterations)
        # Scheduler
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
        else:
            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
-        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False

-        # Dataloaders
-        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
-        if RANK in (-1, 0):
-            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
-            self.validator = self.get_validator()
-            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
-            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
-            self.ema = ModelEMA(self.model)
-            if self.args.plots and not self.args.v5loader:
-                self.plot_training_labels()
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')
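
To make the new `iterations` estimate concrete, here is a small worked sketch. The dataset size, epoch count and batch size are assumed example numbers, not values taken from this diff:

```python
import math

# Assumed example values: a coco128-sized dataset trained for 100 epochs
dataset_size = 128
epochs = 100
batch_size = 16
nbs = 64  # nominal batch size from the defaults table above

# Same formula as the iterations line added in this hunk
iterations = math.ceil(dataset_size / max(batch_size, nbs)) * epochs
print(iterations)  # ceil(128 / 64) * 100 = 200

# With the threshold used in build_optimizer below, 200 <= 6000, so 'auto' would
# select NAdam with lr=0.001; a run estimated above 6000 iterations would get SGD with lr=0.01.
```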
@@ -603,24 +605,30 @@ class BaseTrainer:
                if hasattr(self.train_loader.dataset, 'close_mosaic'):
                    self.train_loader.dataset.close_mosaic(hyp=self.args)

-    @staticmethod
-    def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """
-        Builds an optimizer with the specified parameters and parameter groups.
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
+        momentum, weight decay, and number of iterations.

        Args:
-            model (nn.Module): model to optimize
-            name (str): name of the optimizer to use
-            lr (float): learning rate
-            momentum (float): momentum
-            decay (float): weight decay
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.

        Returns:
-            optimizer (torch.optim.Optimizer): the built optimizer
+            (torch.optim.Optimizer): The constructed optimizer.
        """

        g = [], [], []  # optimizer parameter groups
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == 'auto':
+            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 6000 else ('NAdam', 0.001, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for NAdam

        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
@@ -632,19 +640,21 @@ class BaseTrainer:
                else:  # weight (with decay)
                    g[0].append(param)

-        if name == 'Adam':
-            optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
-        elif name == 'AdamW':
-            optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
        elif name == 'RMSProp':
-            optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
        elif name == 'SGD':
-            optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        else:
-            raise NotImplementedError(f'Optimizer {name} not implemented.')
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers "
+                f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
+                'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')

        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-        LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
-                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
        return optimizer
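
A minimal standalone sketch (toy model, not part of this commit) of how the collapsed Adam-family branch resolves each of the newly supported names through `getattr`:

```python
from torch import nn, optim

# Tiny stand-in model just to provide some parameters
model = nn.Linear(10, 2)

for name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
    # Same lookup as in build_optimizer: resolve the class from torch.optim,
    # falling back to Adam if the attribute is missing
    opt_cls = getattr(optim, name, optim.Adam)
    opt = opt_cls(model.parameters(), lr=0.001, betas=(0.937, 0.999), weight_decay=0.0)
    print(name, '->', type(opt).__name__)
```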

@@ -14,7 +14,7 @@ except ImportError:
    tune = None

default_space = {
-    # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'RMSProp']),
+    # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
    'lr0': tune.uniform(1e-5, 1e-1),
    'lrf': tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
    'momentum': tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
