ultralytics 8.0.44
export and task fixes (#1088)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -208,12 +208,15 @@ class Exporter:
|
||||
self.file = file
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
||||
description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \
|
||||
if self.args.data else '(untrained)'
|
||||
self.metadata = {
|
||||
'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}',
|
||||
'description': description,
|
||||
'author': 'Ultralytics',
|
||||
'license': 'GPL-3.0 https://ultralytics.com/license',
|
||||
'version': __version__,
|
||||
'stride': int(max(model.stride)),
|
||||
'task': model.task,
|
||||
'names': model.names} # model metadata
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||
|
@ -9,76 +9,72 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
MODEL_MAP = {
|
||||
TASK_MAP = {
|
||||
'classify': [
|
||||
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
|
||||
'yolo.TYPE.classify.ClassificationPredictor'],
|
||||
ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
|
||||
yolo.v8.classify.ClassificationPredictor],
|
||||
'detect': [
|
||||
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
|
||||
'yolo.TYPE.detect.DetectionPredictor'],
|
||||
DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
|
||||
yolo.v8.detect.DetectionPredictor],
|
||||
'segment': [
|
||||
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
|
||||
'yolo.TYPE.segment.SegmentationPredictor']}
|
||||
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
|
||||
yolo.v8.segment.SegmentationPredictor]}
|
||||
|
||||
|
||||
class YOLO:
|
||||
"""
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
|
||||
Args:
|
||||
model (str, Path): Path to the model file to load or create.
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
Args:
|
||||
model (str, Path): Path to the model file to load or create.
|
||||
|
||||
Attributes:
|
||||
type (str): Type/version of models being used.
|
||||
ModelClass (Any): Model class.
|
||||
TrainerClass (Any): Trainer class.
|
||||
ValidatorClass (Any): Validator class.
|
||||
PredictorClass (Any): Predictor class.
|
||||
predictor (Any): Predictor object.
|
||||
model (Any): Model object.
|
||||
trainer (Any): Trainer object.
|
||||
task (str): Type of model task.
|
||||
ckpt (Any): Checkpoint object if model loaded from *.pt file.
|
||||
cfg (str): Model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): Checkpoint file path.
|
||||
overrides (dict): Overrides for trainer object.
|
||||
metrics_data (Any): Data for metrics.
|
||||
Attributes:
|
||||
predictor (Any): The predictor object.
|
||||
model (Any): The model object.
|
||||
trainer (Any): The trainer object.
|
||||
task (str): The type of model task.
|
||||
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
|
||||
cfg (str): The model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): The checkpoint file path.
|
||||
overrides (dict): Overrides for the trainer object.
|
||||
metrics_data (Any): The data for metrics.
|
||||
|
||||
Methods:
|
||||
__call__(): Alias for predict method.
|
||||
_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
|
||||
_load(weights): Initializes a new model and infers the task type from the model head.
|
||||
_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
|
||||
reset(): Resets the model modules.
|
||||
info(verbose=False): Logs model info.
|
||||
fuse(): Fuse model for faster inference.
|
||||
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
|
||||
Methods:
|
||||
__call__(source=None, stream=False, **kwargs):
|
||||
Alias for the predict method.
|
||||
_new(cfg:str, verbose:bool=True) -> None:
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
_load(weights:str, task:str='') -> None:
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
_check_is_pytorch_model() -> None:
|
||||
Raises TypeError if the model is not a PyTorch model.
|
||||
reset() -> None:
|
||||
Resets the model modules.
|
||||
info(verbose:bool=False) -> None:
|
||||
Logs the model info.
|
||||
fuse() -> None:
|
||||
Fuses the model for faster inference.
|
||||
predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
|
||||
Performs prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
Returns:
|
||||
list[ultralytics.yolo.engine.results.Results]: The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', type='v8') -> None:
|
||||
def __init__(self, model='yolov8n.pt') -> None:
|
||||
"""
|
||||
Initializes the YOLO model.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.type = type
|
||||
self.ModelClass = None # model class
|
||||
self.TrainerClass = None # trainer class
|
||||
self.ValidatorClass = None # validator class
|
||||
self.PredictorClass = None # predictor class
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
@ -101,6 +97,10 @@ class YOLO:
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
@ -112,11 +112,15 @@ class YOLO:
|
||||
self.cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
|
||||
self.task = guess_model_task(cfg_dict)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1) # initialize
|
||||
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
|
||||
self.overrides['model'] = self.cfg
|
||||
|
||||
def _load(self, weights: str):
|
||||
# Below added to allow export from yamls
|
||||
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
|
||||
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
self.model.task = self.task
|
||||
|
||||
def _load(self, weights: str, task=''):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
@ -127,8 +131,7 @@ class YOLO:
|
||||
if suffix == '.pt':
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.task = self.model.args['task']
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
weights = check_file(weights)
|
||||
@ -136,7 +139,6 @@ class YOLO:
|
||||
self.task = guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.overrides['model'] = weights
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
"""
|
||||
@ -189,12 +191,13 @@ class YOLO:
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs)
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||
assert overrides['mode'] in ['track', 'predict']
|
||||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.predictor = self.PredictorClass(overrides=overrides)
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
@ -226,12 +229,15 @@ class YOLO:
|
||||
overrides['mode'] = 'val'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
if 'task' in overrides:
|
||||
self.task = args.task
|
||||
else:
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
|
||||
validator = self.ValidatorClass(args=args)
|
||||
validator = TASK_MAP[self.task][2](args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics_data = validator.metrics
|
||||
|
||||
@ -267,8 +273,7 @@ class YOLO:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
exporter = Exporter(overrides=args)
|
||||
return exporter(model=self.model)
|
||||
return Exporter(overrides=args)(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
@ -282,15 +287,15 @@ class YOLO:
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get('cfg'):
|
||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
|
||||
overrides['task'] = self.task
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg']))
|
||||
overrides['mode'] = 'train'
|
||||
if not overrides.get('data'):
|
||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||
if overrides.get('resume'):
|
||||
overrides['resume'] = self.ckpt_path
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.trainer = TASK_MAP[self.task][1](overrides=overrides)
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
@ -311,13 +316,6 @@ class YOLO:
|
||||
self._check_is_pytorch_model()
|
||||
self.model.to(device)
|
||||
|
||||
def _assign_ops_from_task(self):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
|
||||
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
|
||||
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
"""
|
||||
@ -357,9 +355,8 @@ class YOLO:
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
|
||||
args.pop(arg, None)
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
@staticmethod
|
||||
def _reset_callbacks():
|
||||
|
@ -108,7 +108,6 @@ class BasePredictor:
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
return preds
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, source=None, model=None, stream=False):
|
||||
if stream:
|
||||
return self.stream_inference(source, model)
|
||||
@ -136,6 +135,7 @@ class BasePredictor:
|
||||
self.source_type = self.dataset.source_type
|
||||
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
||||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None):
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
@ -161,12 +161,14 @@ class BasePredictor:
|
||||
self.batch = batch
|
||||
path, im, im0s, vid_cap, s = batch
|
||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||
|
||||
# preprocess
|
||||
with self.dt[0]:
|
||||
im = self.preprocess(im)
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
|
||||
# Inference
|
||||
# inference
|
||||
with self.dt[1]:
|
||||
preds = self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
|
||||
|
@ -18,29 +18,33 @@ from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||
|
||||
class Results:
|
||||
"""
|
||||
A class for storing and manipulating inference results.
|
||||
A class for storing and manipulating inference results.
|
||||
|
||||
Args:
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
|
||||
orig_img (tuple, optional): Original image size.
|
||||
Args:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
path (str): The path to the image file.
|
||||
names (List[str]): A list of class names.
|
||||
boxes (List[List[float]], optional): A list of bounding box coordinates for each detection.
|
||||
masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image.
|
||||
probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
|
||||
|
||||
Attributes:
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
|
||||
orig_img (tuple, optional): Original image size.
|
||||
data (torch.Tensor): The raw masks tensor
|
||||
|
||||
"""
|
||||
Attributes:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
orig_shape (tuple): The original image shape in (height, width) format.
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
|
||||
names (List[str]): A list of class names.
|
||||
path (str): The path to the image file.
|
||||
_keys (tuple): A tuple of attribute names for non-empty attributes.
|
||||
"""
|
||||
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs if probs is not None else None
|
||||
self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs.cpu() if probs is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
|
||||
@ -99,24 +103,22 @@ class Results:
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"""
|
||||
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
|
||||
|
||||
Attributes:
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
|
||||
orig_shape (tuple, optional): Original image size.
|
||||
""")
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""
|
||||
Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
|
||||
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
||||
|
||||
Args:
|
||||
show_conf (bool): Show confidence
|
||||
line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
|
||||
font_size (Float): The font size of . Automatically scaled to img size if not provided
|
||||
show_conf (bool): Whether to show the detection confidence score.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
|
||||
font (str): The font to use for the text.
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
|
||||
|
||||
Returns:
|
||||
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
|
||||
"""
|
||||
img = deepcopy(self.orig_img)
|
||||
annotator = Annotator(img, line_width, font_size, font, pil, example)
|
||||
@ -157,15 +159,24 @@ class Boxes:
|
||||
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
|
||||
with shape (num_boxes, 6).
|
||||
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
|
||||
is_track (bool): True if the boxes also include track IDs, False otherwise.
|
||||
|
||||
Properties:
|
||||
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
|
||||
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
|
||||
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
|
||||
id (torch.Tensor) or (numpy.ndarray): The track IDs of the boxes (if available).
|
||||
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
|
||||
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
|
||||
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
|
||||
data (torch.Tensor): The raw bboxes tensor
|
||||
|
||||
Methods:
|
||||
cpu(): Move the object to CPU memory.
|
||||
numpy(): Convert the object to a numpy array.
|
||||
cuda(): Move the object to CUDA memory.
|
||||
to(*args, **kwargs): Move the object to the specified device.
|
||||
pandas(): Convert the object to a pandas DataFrame (not yet implemented).
|
||||
"""
|
||||
|
||||
def __init__(self, boxes, orig_shape) -> None:
|
||||
@ -257,22 +268,7 @@ class Boxes:
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"""
|
||||
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
|
||||
|
||||
Attributes:
|
||||
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
|
||||
with shape (num_boxes, 6).
|
||||
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
|
||||
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
|
||||
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
|
||||
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
|
||||
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
|
||||
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
|
||||
""")
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
|
||||
class Masks:
|
||||
@ -288,7 +284,18 @@ class Masks:
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
|
||||
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the masks tensor on CPU memory.
|
||||
numpy(): Returns a copy of the masks tensor as a numpy array.
|
||||
cuda(): Returns a copy of the masks tensor on GPU memory.
|
||||
to(): Returns a copy of the masks tensor with the specified device and dtype.
|
||||
__len__(): Returns the number of masks in the tensor.
|
||||
__str__(): Returns a string representation of the masks tensor.
|
||||
__repr__(): Returns a detailed string representation of the masks tensor.
|
||||
__getitem__(): Returns a new Masks object with the masks at the specified index.
|
||||
__getattr__(): Raises an AttributeError with a list of valid attributes and properties.
|
||||
"""
|
||||
|
||||
def __init__(self, masks, orig_shape) -> None:
|
||||
@ -337,13 +344,4 @@ class Masks:
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"""
|
||||
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
|
||||
|
||||
Attributes:
|
||||
masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
|
||||
""")
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
@ -44,7 +44,6 @@ class BaseTrainer:
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the trainer.
|
||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||
console (logging.Logger): Logger instance.
|
||||
validator (BaseValidator): Validator instance.
|
||||
model (nn.Module): Model instance.
|
||||
callbacks (defaultdict): Dictionary of callbacks.
|
||||
@ -84,7 +83,6 @@ class BaseTrainer:
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.check_resume()
|
||||
self.console = LOGGER
|
||||
self.validator = None
|
||||
self.model = None
|
||||
self.metrics = None
|
||||
@ -180,11 +178,12 @@ class BaseTrainer:
|
||||
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
||||
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
|
||||
try:
|
||||
LOGGER.info(f'Running DDP command {cmd}')
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
self.console.warning(e)
|
||||
LOGGER.warning(e)
|
||||
finally:
|
||||
ddp_cleanup(self, file)
|
||||
ddp_cleanup(self, str(file))
|
||||
else:
|
||||
self._do_train(RANK, world_size)
|
||||
|
||||
@ -193,7 +192,7 @@ class BaseTrainer:
|
||||
# os.environ['MASTER_PORT'] = '9020'
|
||||
torch.cuda.set_device(rank)
|
||||
self.device = torch.device('cuda', rank)
|
||||
self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
@ -262,10 +261,10 @@ class BaseTrainer:
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||
last_opt_step = -1
|
||||
self.run_callbacks('on_train_start')
|
||||
self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for {self.epochs} epochs...')
|
||||
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for {self.epochs} epochs...')
|
||||
if self.args.close_mosaic:
|
||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||
@ -278,14 +277,14 @@ class BaseTrainer:
|
||||
pbar = enumerate(self.train_loader)
|
||||
# Update dataloader attributes (optional)
|
||||
if epoch == (self.epochs - self.args.close_mosaic):
|
||||
self.console.info('Closing dataloader mosaic')
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
if rank in {-1, 0}:
|
||||
self.console.info(self.progress_string())
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
|
||||
self.tloss = None
|
||||
self.optimizer.zero_grad()
|
||||
@ -372,12 +371,11 @@ class BaseTrainer:
|
||||
|
||||
if rank in {-1, 0}:
|
||||
# Do final val with best.pt
|
||||
self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
self.final_eval()
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
self.run_callbacks('on_train_end')
|
||||
torch.cuda.empty_cache()
|
||||
self.run_callbacks('teardown')
|
||||
@ -450,18 +448,6 @@ class BaseTrainer:
|
||||
self.best_fitness = fitness
|
||||
return metrics, fitness
|
||||
|
||||
def log(self, text, rank=-1):
|
||||
"""
|
||||
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
|
||||
Args"
|
||||
text (str): text to log
|
||||
rank (List[Int]): process rank
|
||||
|
||||
"""
|
||||
if rank in {-1, 0}:
|
||||
self.console.info(text)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
@ -521,7 +507,7 @@ class BaseTrainer:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
self.console.info(f'\nValidating {f}...')
|
||||
LOGGER.info(f'\nValidating {f}...')
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
@ -564,7 +550,7 @@ class BaseTrainer:
|
||||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||
self.console.info('Closing dataloader mosaic')
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
|
@ -44,7 +44,6 @@ class BaseValidator:
|
||||
Attributes:
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
pbar (tqdm): Progress bar to update during validation.
|
||||
logger (logging.Logger): Logger to use for validation.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary.
|
||||
@ -56,7 +55,7 @@ class BaseValidator:
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
|
||||
"""
|
||||
Initializes a BaseValidator instance.
|
||||
|
||||
@ -69,14 +68,13 @@ class BaseValidator:
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.logger = logger or LOGGER
|
||||
self.args = args or get_cfg(DEFAULT_CFG)
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.speed = None
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.jdict = None
|
||||
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
@ -123,14 +121,14 @@ class BaseValidator:
|
||||
self.device = model.device
|
||||
if not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == 'classify':
|
||||
self.data = check_cls_dataset(self.args.data)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type == 'cpu':
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
@ -179,7 +177,7 @@ class BaseValidator:
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
self.print_results()
|
||||
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
|
||||
self.finalize_metrics()
|
||||
self.run_callbacks('on_val_end')
|
||||
if self.training:
|
||||
@ -187,11 +185,11 @@ class BaseValidator:
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
|
||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||
else:
|
||||
self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
|
||||
self.speed)
|
||||
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
|
||||
tuple(self.speed.values()))
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
|
||||
self.logger.info(f'Saving {f.name}...')
|
||||
LOGGER.info(f'Saving {f.name}...')
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
stats = self.eval_json(stats) # update stats
|
||||
if self.args.plots or self.args.save_json:
|
||||
|
Reference in New Issue
Block a user