ultralytics 8.0.44 export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-24 03:11:25 +01:00
committed by GitHub
parent fe61018975
commit 3ea659411b
32 changed files with 439 additions and 480 deletions

View File

@ -208,12 +208,15 @@ class Exporter:
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \
if self.args.data else '(untrained)'
self.metadata = {
'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}',
'description': description,
'author': 'Ultralytics',
'license': 'GPL-3.0 https://ultralytics.com/license',
'version': __version__,
'stride': int(max(model.stride)),
'task': model.task,
'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "

View File

@ -9,76 +9,72 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
TASK_MAP = {
'classify': [
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
'yolo.TYPE.classify.ClassificationPredictor'],
ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
yolo.v8.classify.ClassificationPredictor],
'detect': [
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
'yolo.TYPE.detect.DetectionPredictor'],
DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
yolo.v8.detect.DetectionPredictor],
'segment': [
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
'yolo.TYPE.segment.SegmentationPredictor']}
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
yolo.v8.segment.SegmentationPredictor]}
class YOLO:
"""
YOLO (You Only Look Once) object detection model.
YOLO (You Only Look Once) object detection model.
Args:
model (str, Path): Path to the model file to load or create.
type (str): Type/version of models to use. Defaults to "v8".
Args:
model (str, Path): Path to the model file to load or create.
Attributes:
type (str): Type/version of models being used.
ModelClass (Any): Model class.
TrainerClass (Any): Trainer class.
ValidatorClass (Any): Validator class.
PredictorClass (Any): Predictor class.
predictor (Any): Predictor object.
model (Any): Model object.
trainer (Any): Trainer object.
task (str): Type of model task.
ckpt (Any): Checkpoint object if model loaded from *.pt file.
cfg (str): Model configuration if loaded from *.yaml file.
ckpt_path (str): Checkpoint file path.
overrides (dict): Overrides for trainer object.
metrics_data (Any): Data for metrics.
Attributes:
predictor (Any): The predictor object.
model (Any): The model object.
trainer (Any): The trainer object.
task (str): The type of model task.
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics_data (Any): The data for metrics.
Methods:
__call__(): Alias for predict method.
_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
_load(weights): Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
reset(): Resets the model modules.
info(verbose=False): Logs model info.
fuse(): Fuse model for faster inference.
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
Methods:
__call__(source=None, stream=False, **kwargs):
Alias for the predict method.
_new(cfg:str, verbose:bool=True) -> None:
Initializes a new model and infers the task type from the model definitions.
_load(weights:str, task:str='') -> None:
Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model() -> None:
Raises TypeError if the model is not a PyTorch model.
reset() -> None:
Resets the model modules.
info(verbose:bool=False) -> None:
Logs the model info.
fuse() -> None:
Fuses the model for faster inference.
predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
Performs prediction using the YOLO model.
Returns:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
Returns:
list[ultralytics.yolo.engine.results.Results]: The prediction results.
"""
def __init__(self, model='yolov8n.pt', type='v8') -> None:
def __init__(self, model='yolov8n.pt') -> None:
"""
Initializes the YOLO model.
Args:
model (str, Path): model to load or create
type (str): Type/version of models to use. Defaults to "v8".
"""
self._reset_callbacks()
self.type = type
self.ModelClass = None # model class
self.TrainerClass = None # trainer class
self.ValidatorClass = None # validator class
self.PredictorClass = None # predictor class
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
@ -101,6 +97,10 @@ class YOLO:
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def _new(self, cfg: str, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
@ -112,11 +112,15 @@ class YOLO:
self.cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
self.task = guess_model_task(cfg_dict)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1) # initialize
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg
def _load(self, weights: str):
# Below added to allow export from yamls
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
self.model.task = self.task
def _load(self, weights: str, task=''):
"""
Initializes a new model and infers the task type from the model head.
@ -127,8 +131,7 @@ class YOLO:
if suffix == '.pt':
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args['task']
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
weights = check_file(weights)
@ -136,7 +139,6 @@ class YOLO:
self.task = guess_model_task(weights)
self.ckpt_path = weights
self.overrides['model'] = weights
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
def _check_is_pytorch_model(self):
"""
@ -189,12 +191,13 @@ class YOLO:
"""
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs)
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
@ -226,12 +229,15 @@ class YOLO:
overrides['mode'] = 'val'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
if 'task' in overrides:
self.task = args.task
else:
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = self.ValidatorClass(args=args)
validator = TASK_MAP[self.task][2](args=args)
validator(model=self.model)
self.metrics_data = validator.metrics
@ -267,8 +273,7 @@ class YOLO:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
exporter = Exporter(overrides=args)
return exporter(model=self.model)
return Exporter(overrides=args)(model=self.model)
def train(self, **kwargs):
"""
@ -282,15 +287,15 @@ class YOLO:
overrides.update(kwargs)
if kwargs.get('cfg'):
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
overrides['task'] = self.task
overrides = yaml_load(check_yaml(kwargs['cfg']))
overrides['mode'] = 'train'
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.trainer = self.TrainerClass(overrides=overrides)
self.task = overrides.get('task') or self.task
self.trainer = TASK_MAP[self.task][1](overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
@ -311,13 +316,6 @@ class YOLO:
self._check_is_pytorch_model()
self.model.to(device)
def _assign_ops_from_task(self):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
return model_class, trainer_class, validator_class, predictor_class
@property
def names(self):
"""
@ -357,9 +355,8 @@ class YOLO:
@staticmethod
def _reset_ckpt_args(args):
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
args.pop(arg, None)
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
@staticmethod
def _reset_callbacks():

View File

@ -108,7 +108,6 @@ class BasePredictor:
def postprocess(self, preds, img, orig_img):
return preds
@smart_inference_mode()
def __call__(self, source=None, model=None, stream=False):
if stream:
return self.stream_inference(source, model)
@ -136,6 +135,7 @@ class BasePredictor:
self.source_type = self.dataset.source_type
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()
def stream_inference(self, source=None, model=None):
if self.args.verbose:
LOGGER.info('')
@ -161,12 +161,14 @@ class BasePredictor:
self.batch = batch
path, im, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
# preprocess
with self.dt[0]:
im = self.preprocess(im)
if len(im.shape) == 3:
im = im[None] # expand for batch dim
# Inference
# inference
with self.dt[1]:
preds = self.model(im, augment=self.args.augment, visualize=visualize)

View File

@ -18,29 +18,33 @@ from ultralytics.yolo.utils.plotting import Annotator, colors
class Results:
"""
A class for storing and manipulating inference results.
A class for storing and manipulating inference results.
Args:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_img (tuple, optional): Original image size.
Args:
orig_img (numpy.ndarray): The original image as a numpy array.
path (str): The path to the image file.
names (List[str]): A list of class names.
boxes (List[List[float]], optional): A list of bounding box coordinates for each detection.
masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image.
probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
Attributes:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_img (tuple, optional): Original image size.
data (torch.Tensor): The raw masks tensor
"""
Attributes:
orig_img (numpy.ndarray): The original image as a numpy array.
orig_shape (tuple): The original image shape in (height, width) format.
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
names (List[str]): A list of class names.
path (str): The path to the image file.
_keys (tuple): A tuple of attribute names for non-empty attributes.
"""
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
self.orig_img = orig_img
self.orig_shape = orig_img.shape[:2]
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs if probs is not None else None
self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs.cpu() if probs is not None else None
self.names = names
self.path = path
self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
@ -99,24 +103,22 @@ class Results:
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
Attributes:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_shape (tuple, optional): Original image size.
""")
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""
Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
Args:
show_conf (bool): Show confidence
line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
font_size (Float): The font size of . Automatically scaled to img size if not provided
show_conf (bool): Whether to show the detection confidence score.
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
Returns:
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
"""
img = deepcopy(self.orig_img)
annotator = Annotator(img, line_width, font_size, font, pil, example)
@ -157,15 +159,24 @@ class Boxes:
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 6).
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
is_track (bool): True if the boxes also include track IDs, False otherwise.
Properties:
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
id (torch.Tensor) or (numpy.ndarray): The track IDs of the boxes (if available).
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
data (torch.Tensor): The raw bboxes tensor
Methods:
cpu(): Move the object to CPU memory.
numpy(): Convert the object to a numpy array.
cuda(): Move the object to CUDA memory.
to(*args, **kwargs): Move the object to the specified device.
pandas(): Convert the object to a pandas DataFrame (not yet implemented).
"""
def __init__(self, boxes, orig_shape) -> None:
@ -257,22 +268,7 @@ class Boxes:
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
Attributes:
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 6).
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
Properties:
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
""")
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
class Masks:
@ -288,7 +284,18 @@ class Masks:
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.
numpy(): Returns a copy of the masks tensor as a numpy array.
cuda(): Returns a copy of the masks tensor on GPU memory.
to(): Returns a copy of the masks tensor with the specified device and dtype.
__len__(): Returns the number of masks in the tensor.
__str__(): Returns a string representation of the masks tensor.
__repr__(): Returns a detailed string representation of the masks tensor.
__getitem__(): Returns a new Masks object with the masks at the specified index.
__getattr__(): Raises an AttributeError with a list of valid attributes and properties.
"""
def __init__(self, masks, orig_shape) -> None:
@ -337,13 +344,4 @@ class Masks:
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
Attributes:
masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
""")
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

View File

@ -44,7 +44,6 @@ class BaseTrainer:
Attributes:
args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
console (logging.Logger): Logger instance.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
@ -84,7 +83,6 @@ class BaseTrainer:
self.args = get_cfg(cfg, overrides)
self.device = select_device(self.args.device, self.args.batch)
self.check_resume()
self.console = LOGGER
self.validator = None
self.model = None
self.metrics = None
@ -180,11 +178,12 @@ class BaseTrainer:
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
try:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True)
except Exception as e:
self.console.warning(e)
LOGGER.warning(e)
finally:
ddp_cleanup(self, file)
ddp_cleanup(self, str(file))
else:
self._do_train(RANK, world_size)
@ -193,7 +192,7 @@ class BaseTrainer:
# os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
def _setup_train(self, rank, world_size):
@ -262,10 +261,10 @@ class BaseTrainer:
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
self.run_callbacks('on_train_start')
self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for {self.epochs} epochs...')
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for {self.epochs} epochs...')
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
@ -278,14 +277,14 @@ class BaseTrainer:
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
self.console.info('Closing dataloader mosaic')
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
if rank in {-1, 0}:
self.console.info(self.progress_string())
LOGGER.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
self.tloss = None
self.optimizer.zero_grad()
@ -372,12 +371,11 @@ class BaseTrainer:
if rank in {-1, 0}:
# Do final val with best.pt
self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
self.final_eval()
if self.args.plots:
self.plot_metrics()
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
self.run_callbacks('on_train_end')
torch.cuda.empty_cache()
self.run_callbacks('teardown')
@ -450,18 +448,6 @@ class BaseTrainer:
self.best_fitness = fitness
return metrics, fitness
def log(self, text, rank=-1):
"""
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
Args"
text (str): text to log
rank (List[Int]): process rank
"""
if rank in {-1, 0}:
self.console.info(text)
def get_model(self, cfg=None, weights=None, verbose=True):
raise NotImplementedError("This task trainer doesn't support loading cfg files")
@ -521,7 +507,7 @@ class BaseTrainer:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
self.console.info(f'\nValidating {f}...')
LOGGER.info(f'\nValidating {f}...')
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end')
@ -564,7 +550,7 @@ class BaseTrainer:
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self.console.info('Closing dataloader mosaic')
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):

View File

@ -44,7 +44,6 @@ class BaseValidator:
Attributes:
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
@ -56,7 +55,7 @@ class BaseValidator:
save_dir (Path): Directory to save results.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
"""
Initializes a BaseValidator instance.
@ -69,14 +68,13 @@ class BaseValidator:
"""
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None
self.data = None
self.device = None
self.batch_i = None
self.training = True
self.speed = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.jdict = None
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
@ -123,14 +121,14 @@ class BaseValidator:
self.device = model.device
if not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
@ -179,7 +177,7 @@ class BaseValidator:
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
self.finalize_metrics()
self.run_callbacks('on_val_end')
if self.training:
@ -187,11 +185,11 @@ class BaseValidator:
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
self.speed)
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
tuple(self.speed.values()))
if self.args.save_json and self.jdict:
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
self.logger.info(f'Saving {f.name}...')
LOGGER.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json: