From 2e7a533ac3e6bde645797abc3394d6d7bdf6274f Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Wed, 8 Feb 2023 05:33:25 +0800
Subject: [PATCH] Return metrics, Update docs (#846)

---
 README.md                         |   4 +-
 docs/cfg.md                       | 148 ++++++++++++++++--------
 docs/cli.md                       |  98 ++++----------------
 docs/python.md                    |   4 +-
 docs/tasks/classification.md      |   7 +-
 docs/tasks/detection.md           |   9 +-
 docs/tasks/segmentation.md        |  13 ++-
 ultralytics/yolo/engine/model.py  |   1 +
 ultralytics/yolo/utils/metrics.py |  56 +++++++----
 9 files changed, 161 insertions(+), 179 deletions(-)

diff --git a/README.md b/README.md
index 0376d83..ed6970e 100644
--- a/README.md
+++ b/README.md
@@ -98,8 +98,8 @@ model = YOLO("yolov8n.yaml")  # build a new model from scratch
model = YOLO("yolov8n.pt")  # load a pretrained model (recommended for training)

# Use the model
-results = model.train(data="coco128.yaml", epochs=3)  # train the model
-results = model.val()  # evaluate model performance on the validation set
+model.train(data="coco128.yaml", epochs=3)  # train the model
+metrics = model.val()  # evaluate model performance on the validation set
results = model("https://ultralytics.com/images/bus.jpg")  # predict on an image
success = model.export(format="onnx")  # export the model to ONNX format
```

diff --git a/docs/cfg.md b/docs/cfg.md
index 9c75ea2..cdea450 100644
--- a/docs/cfg.md
+++ b/docs/cfg.md
@@ -66,49 +66,50 @@ include the choice of optimizer, the choice of loss function, and the size and c
is important to carefully tune and experiment with these settings to achieve the best possible performance for a given
task.

-| Key             | Value  | Description                                                                  |
-|-----------------|--------|------------------------------------------------------------------------------|
-| model           | null   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                            |
-| data            | null   | path to data file, i.e. coco128.yaml                                         |
-| epochs          | 100    | number of epochs to train for                                                |
-| patience        | 50     | epochs to wait for no observable improvement for early stopping of training  |
-| batch           | 16     | number of images per batch (-1 for AutoBatch)                                |
-| imgsz           | 640    | size of input images as integer or w,h                                       |
-| save            | True   | save train checkpoints and predict results                                   |
-| cache           | False  | True/ram, disk or False. Use cache for data loading                          |
-| device          | null   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu         |
-| workers         | 8      | number of worker threads for data loading (per RANK if DDP)                  |
-| project         | null   | project name                                                                 |
-| name            | null   | experiment name                                                              |
-| exist_ok        | False  | whether to overwrite existing experiment                                     |
-| pretrained      | False  | whether to use a pretrained model                                            |
-| optimizer       | 'SGD'  | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']                |
-| verbose         | False  | whether to print verbose output                                              |
-| seed            | 0      | random seed for reproducibility                                              |
-| deterministic   | True   | whether to enable deterministic mode                                         |
-| single_cls      | False  | train multi-class data as single-class                                       |
-| image_weights   | False  | use weighted image selection for training                                    |
-| rect            | False  | support rectangular training                                                 |
-| cos_lr          | False  | use cosine learning rate scheduler                                           |
-| close_mosaic    | 10     | disable mosaic augmentation for final 10 epochs                              |
-| resume          | False  | resume training from last checkpoint                                         |
-| lr0             | 0.01   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                             |
-| lrf             | 0.01   | final learning rate (lr0 * lrf)                                              |
-| momentum        | 0.937  | SGD momentum/Adam beta1                                                      |
-| weight_decay    | 0.0005 | optimizer weight decay 5e-4                                                  |
-| warmup_epochs   | 3.0    | warmup epochs (fractions ok)                                                 |
-| warmup_momentum | 0.8    | warmup initial momentum                                                      |
-| warmup_bias_lr  | 0.1    | warmup initial bias lr                                                       |
-| box             | 7.5    | box loss gain                                                                |
-| cls             | 0.5    | cls loss gain (scale with pixels)                                            |
-| dfl             | 1.5    | dfl loss gain                                                                |
-| fl_gamma        | 0.0    | focal loss gamma (efficientDet default gamma=1.5)                            |
-| label_smoothing | 0.0    | label smoothing (fraction)                                                   |
-| nbs             | 64     | nominal batch size                                                           |
-| overlap_mask    | True   | masks should overlap during training (segment train only)                    |
-| mask_ratio      | 4      | mask downsample ratio (segment train only)                                   |
-| dropout         | 0.0    | use dropout regularization (classify train only)                             |
-| val             | True   | validate/test during training                                                |
+| Key             | Value  | Description                                                                     |
+|-----------------|--------|---------------------------------------------------------------------------------|
+| model           | null   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                               |
+| data            | null   | path to data file, i.e. coco128.yaml                                            |
+| epochs          | 100    | number of epochs to train for                                                   |
+| patience        | 50     | epochs to wait for no observable improvement for early stopping of training     |
+| batch           | 16     | number of images per batch (-1 for AutoBatch)                                   |
+| imgsz           | 640    | size of input images as integer or w,h                                          |
+| save            | True   | save train checkpoints and predict results                                      |
+| cache           | False  | True/ram, disk or False. Use cache for data loading                             |
+| device          | null   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu            |
+| workers         | 8      | number of worker threads for data loading (per RANK if DDP)                     |
+| project         | null   | project name                                                                    |
+| name            | null   | experiment name                                                                 |
+| exist_ok        | False  | whether to overwrite existing experiment                                        |
+| pretrained      | False  | whether to use a pretrained model                                               |
+| optimizer       | 'SGD'  | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']                   |
+| verbose         | False  | whether to print verbose output                                                 |
+| seed            | 0      | random seed for reproducibility                                                 |
+| deterministic   | True   | whether to enable deterministic mode                                            |
+| single_cls      | False  | train multi-class data as single-class                                          |
+| image_weights   | False  | use weighted image selection for training                                       |
+| rect            | False  | support rectangular training                                                    |
+| cos_lr          | False  | use cosine learning rate scheduler                                              |
+| close_mosaic    | 10     | disable mosaic augmentation for final 10 epochs                                 |
+| resume          | False  | resume training from last checkpoint                                            |
+| lr0             | 0.01   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                                |
+| lrf             | 0.01   | final learning rate (lr0 * lrf)                                                 |
+| momentum        | 0.937  | SGD momentum/Adam beta1                                                         |
+| weight_decay    | 0.0005 | optimizer weight decay 5e-4                                                     |
+| warmup_epochs   | 3.0    | warmup epochs (fractions ok)                                                    |
+| warmup_momentum | 0.8    | warmup initial momentum                                                         |
+| warmup_bias_lr  | 0.1    | warmup initial bias lr                                                          |
+| box             | 7.5    | box loss gain                                                                   |
+| cls             | 0.5    | cls loss gain (scale with pixels)                                               |
+| dfl             | 1.5    | dfl loss gain                                                                   |
+| fl_gamma        | 0.0    | focal loss gamma (efficientDet default gamma=1.5)                               |
+| label_smoothing | 0.0    | label smoothing (fraction)                                                      |
+| nbs             | 64     | nominal batch size                                                              |
+| overlap_mask    | True   | masks should overlap during training (segment train only)                       |
+| mask_ratio      | 4      | mask downsample ratio (segment train only)                                      |
+| dropout         | 0.0    | use dropout regularization (classify train only)                                |
+| val             | True   | validate/test during training                                                   |
+| min_memory      | False  | minimize memory footprint loss function, choices=[False, True, <imgsz>]         |

### Prediction

@@ -120,22 +121,28 @@ presence of additional features such as masks or multiple labels per box, and th
for. It is important to carefully tune and experiment with these settings to achieve the best possible performance for
a given task.

-| Key            | Value                | Description                                               |
-|----------------|----------------------|-----------------------------------------------------------|
-| source         | 'ultralytics/assets' | source directory for images or videos                     |
-| show           | False                | show results if possible                                  |
-| save_txt       | False                | save results as .txt file                                 |
-| save_conf      | False                | save results with confidence scores                       |
-| save_crop      | False                | save cropped images with results                          |
-| hide_labels    | False                | hide labels                                               |
-| hide_conf      | False                | hide confidence scores                                    |
-| vid_stride     | False                | video frame-rate stride                                   |
-| line_thickness | 3                    | bounding box thickness (pixels)                           |
-| visualize      | False                | visualize model features                                  |
-| augment        | False                | apply image augmentation to prediction sources            |
-| agnostic_nms   | False                | class-agnostic NMS                                        |
-| retina_masks   | False                | use high-resolution segmentation masks                    |
-| classes        | null                 | filter results by class, i.e. class=0, or class=[0,2,3]   |
+| Key            | Value                | Description                                               |
+|----------------|----------------------|------------------------------------------------------------|
+| source         | 'ultralytics/assets' | source directory for images or videos                      |
+| conf           | 0.25                 | object confidence threshold for detection                  |
+| iou            | 0.7                  | intersection over union (IoU) threshold for NMS            |
+| half           | False                | use half precision (FP16)                                  |
+| device         | null                 | device to run on, i.e. cuda device=0/1/2/3 or device=cpu   |
+| show           | False                | show results if possible                                   |
+| save_txt       | False                | save results as .txt file                                  |
+| save_conf      | False                | save results with confidence scores                        |
+| save_crop      | False                | save cropped images with results                           |
+| hide_labels    | False                | hide labels                                                |
+| hide_conf      | False                | hide confidence scores                                     |
+| max_det        | 300                  | maximum number of detections per image                     |
+| vid_stride     | False                | video frame-rate stride                                    |
+| line_thickness | 3                    | bounding box thickness (pixels)                            |
+| visualize      | False                | visualize model features                                   |
+| augment        | False                | apply image augmentation to prediction sources             |
+| agnostic_nms   | False                | class-agnostic NMS                                         |
+| retina_masks   | False                | use high-resolution segmentation masks                     |
+| classes        | null                 | filter results by class, i.e. class=0, or class=[0,2,3]    |
+| box            | True                 | show boxes in segmentation predictions                     |

### Validation

@@ -147,17 +154,18 @@ process include the size and composition of the validation dataset and the speci
is important to carefully tune and experiment with these settings to ensure that the model is performing well on the
validation dataset and to detect and prevent overfitting.

-| Key         | Value | Description                                                                  |
-|-------------|-------|-------------------------------------------------------------------------------|
-| save_json   | False | save results to JSON file                                                    |
-| save_hybrid | False | save hybrid version of labels (labels + additional predictions)              |
-| conf        | 0.001 | object confidence threshold for detection (default 0.25 predict, 0.001 val)  |
-| iou         | 0.6   | intersection over union (IoU) threshold for NMS                              |
-| max_det     | 300   | maximum number of detections per image                                       |
-| half        | True  | use half precision (FP16)                                                    |
-| dnn         | False | use OpenCV DNN for ONNX inference                                            |
-| plots       | False | show plots during training                                                   |
-| rect        | False | support rectangular evaluation                                               |
+| Key         | Value | Description                                                      |
+|-------------|-------|------------------------------------------------------------------|
+| save_json   | False | save results to JSON file                                        |
+| save_hybrid | False | save hybrid version of labels (labels + additional predictions)  |
+| conf        | 0.001 | object confidence threshold for detection                        |
+| iou         | 0.6   | intersection over union (IoU) threshold for NMS                  |
+| max_det     | 300   | maximum number of detections per image                           |
+| half        | True  | use half precision (FP16)                                        |
+| device      | null  | device to run on, i.e. cuda device=0/1/2/3 or device=cpu         |
+| dnn         | False | use OpenCV DNN for ONNX inference                                |
+| plots       | False | show plots during training                                       |
+| rect        | False | support rectangular evaluation                                   |

### Export

diff --git a/docs/cli.md b/docs/cli.md
index fa111cd..0f809b8 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -35,103 +35,41 @@ the [Configuration](cfg.md) page.

!!! example ""

-    === "CLI"
-
-    ```bash
-    yolo detect train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640
-    ```
-
-    === "Python"
-
-    ```python
-    from ultralytics import YOLO
-
-    # Load a model
-    model = YOLO("yolov8n.yaml")  # build a new model from scratch
-    model = YOLO("yolov8n.pt")  # load a pretrained model (recommended for training)
-
-    # Train the model
-    results = model.train(data="coco128.yaml", epochs=100, imgsz=640)
-    ```
-
+    ```bash
+    yolo detect train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640
+    yolo detect train resume model=last.pt  # resume training
+    ```

## Val

Validate trained YOLOv8n model accuracy on the COCO128 dataset. No arguments need to be passed, as the `model` retains
its training `data` and arguments as model attributes.

!!! example ""
-
-    === "CLI"
-
-    ```bash
-    yolo detect val model=yolov8n.pt  # val official model
-    yolo detect val model=path/to/best.pt  # val custom model
-    ```
-
-    === "Python"
-
-    ```python
-    from ultralytics import YOLO
-
-    # Load a model
-    model = YOLO("yolov8n.pt")  # load an official model
-    model = YOLO("path/to/best.pt")  # load a custom model
-
-    # Validate the model
-    results = model.val()  # no arguments needed, dataset and settings remembered
-    ```
-
+
+    ```bash
+    yolo detect val model=yolov8n.pt  # val official model
+    yolo detect val model=path/to/best.pt  # val custom model
+    ```
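Reviewer note: the headline change in this PR is that `val()` now returns the validator's metrics (see the `model.py` hunk further down), so the CLI call above has a direct Python equivalent. A minimal sketch of what that enables, using the model and metric attributes named in the docs updated by this patch:

```python
from ultralytics import YOLO

# With this PR, val() returns the validator's metrics object instead of None
model = YOLO("yolov8n.pt")
metrics = model.val()  # dataset and settings are remembered from training

print(metrics.box.map)    # mAP50-95
print(metrics.box.map50)  # mAP50
```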
example "" - === "CLI" - - ```bash - yolo detect predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model - yolo detect predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model - ``` - - === "Python" - - ```python - from ultralytics import YOLO - - # Load a model - model = YOLO("yolov8n.pt") # load an official model - model = YOLO("path/to/best.pt") # load a custom model - - # Predict with the model - results = model("https://ultralytics.com/images/bus.jpg") # predict on an image - ``` - + ```bash + yolo detect predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model + yolo detect predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model + ``` ## Export Export a YOLOv8n model to a different format like ONNX, CoreML, etc. !!! example "" - - === "CLI" - - ```bash - yolo export model=yolov8n.pt format=onnx # export official model - yolo export model=path/to/best.pt format=onnx # export custom trained model - ``` - - === "Python" - ```python - from ultralytics import YOLO - - # Load a model - model = YOLO("yolov8n.pt") # load an official model - model = YOLO("path/to/best.pt") # load a custom trained - - # Export the model - model.export(format="onnx") - ``` + ```bash + yolo export model=yolov8n.pt format=onnx # export official model + yolo export model=path/to/best.pt format=onnx # export custom trained model + ``` Available YOLOv8 export formats include: @@ -194,4 +132,4 @@ like `imgsz=320` in this example: ```bash yolo copy-cfg yolo cfg=default_copy.yaml imgsz=320 - ``` \ No newline at end of file + ``` diff --git a/docs/python.md b/docs/python.md index 22a88d7..90770a6 100644 --- a/docs/python.md +++ b/docs/python.md @@ -20,7 +20,9 @@ The simplest way of simply using YOLOv8 directly in a Python environment. === "Resume" ```python - TODO: Resume feature is under development and should be released soon. + # TODO: Resume feature is under development and should be released soon. + model = YOLO("last.pt") + model.train(resume=True) ``` !!! example "Val" diff --git a/docs/tasks/classification.md b/docs/tasks/classification.md index 34e9dea..0f1ac3d 100644 --- a/docs/tasks/classification.md +++ b/docs/tasks/classification.md @@ -30,7 +30,7 @@ see the [Configuration](../cfg.md) page. model = YOLO("yolov8n-cls.pt") # load a pretrained model (recommended for training) # Train the model - results = model.train(data="mnist160", epochs=100, imgsz=64) + model.train(data="mnist160", epochs=100, imgsz=64) ``` === "CLI" @@ -55,7 +55,9 @@ it's training `data` and arguments as model attributes. model = YOLO("path/to/best.pt") # load a custom model # Validate the model - results = model.val() # no arguments needed, dataset and settings remembered + metrics = model.val() # no arguments needed, dataset and settings remembered + metrics.top1 # top1 accuracy + metrics.top5 # top5 accuracy ``` === "CLI" @@ -88,6 +90,7 @@ Use a trained YOLOv8n-cls model to run predictions on images. yolo classify predict model=yolov8n-cls.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model yolo classify predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model ``` +Read more details of `predict` in our [Predict](https://docs.ultralytics.com/predict/) page. 
## Export diff --git a/docs/tasks/detection.md b/docs/tasks/detection.md index ac2af7c..4374de2 100644 --- a/docs/tasks/detection.md +++ b/docs/tasks/detection.md @@ -30,7 +30,7 @@ the [Configuration](../cfg.md) page. model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training) # Train the model - results = model.train(data="coco128.yaml", epochs=100, imgsz=640) + model.train(data="coco128.yaml", epochs=100, imgsz=640) ``` === "CLI" @@ -55,7 +55,11 @@ training `data` and arguments as model attributes. model = YOLO("path/to/best.pt") # load a custom model # Validate the model - results = model.val() # no arguments needed, dataset and settings remembered + metrics = model.val() # no arguments needed, dataset and settings remembered + metrics.box.map # map50-95 + metrics.box.map50 # map50 + metrics.box.map75 # map75 + metrics.box.maps # a list contains map50-95 of each category ``` === "CLI" @@ -88,6 +92,7 @@ Use a trained YOLOv8n model to run predictions on images. yolo detect predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model yolo detect predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model ``` +Read more details of `predict` in our [Predict](https://docs.ultralytics.com/predict/) page. ## Export diff --git a/docs/tasks/segmentation.md b/docs/tasks/segmentation.md index 0a0ffc1..0dcdc54 100644 --- a/docs/tasks/segmentation.md +++ b/docs/tasks/segmentation.md @@ -30,7 +30,7 @@ arguments see the [Configuration](../cfg.md) page. model = YOLO("yolov8n-seg.pt") # load a pretrained model (recommended for training) # Train the model - results = model.train(data="coco128-seg.yaml", epochs=100, imgsz=640) + model.train(data="coco128-seg.yaml", epochs=100, imgsz=640) ``` === "CLI" @@ -55,7 +55,15 @@ retains it's training `data` and arguments as model attributes. model = YOLO("path/to/best.pt") # load a custom model # Validate the model - results = model.val() # no arguments needed, dataset and settings remembered + metrics = model.val() # no arguments needed, dataset and settings remembered + metrics.box.map # map50-95(B) + metrics.box.map50 # map50(B) + metrics.box.map75 # map75(B) + metrics.box.maps # a list contains map50-95(B) of each category + metrics.seg.map # map50-95(M) + metrics.seg.map50 # map50(M) + metrics.seg.map75 # map75(M) + metrics.seg.maps # a list contains map50-95(M) of each category ``` === "CLI" @@ -88,6 +96,7 @@ Use a trained YOLOv8n-seg model to run predictions on images. yolo segment predict model=yolov8n-seg.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model yolo segment predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model ``` +Read more details of `predict` in our [Predict](https://docs.ultralytics.com/predict/) page. 
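The `predict` cross-references added above stop short of an inline example. For context, a short sketch of the Python side; the `Results` attributes used here (`boxes`, `masks`) come from the current results API, which is not part of this patch:

```python
from ultralytics import YOLO

# Run a trained segmentation model on an image and inspect the outputs
model = YOLO("yolov8n-seg.pt")
results = model("https://ultralytics.com/images/bus.jpg")  # one Results object per image
for r in results:
    print(r.boxes)  # box outputs
    print(r.masks)  # segmentation mask outputs
```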
## Export diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 5fba06d..10db5d1 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -168,6 +168,7 @@ class YOLO: validator = self.ValidatorClass(args=args) validator(model=self.model) + return validator.metrics @smart_inference_mode() def export(self, **kwargs): diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index cd8a88a..3c10891 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -418,6 +418,7 @@ class Metric: self.f1 = [] # (nc, ) self.all_ap = [] # (nc, 10) self.ap_class_index = [] # (nc, ) + self.nc = 0 @property def ap50(self): @@ -459,6 +460,14 @@ class Metric: """ return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 + @property + def map75(self): + """Mean AP@0.75 of all classes. + Return: + float. + """ + return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0 + @property def map(self): """Mean AP@0.5:0.95 of all classes. @@ -475,8 +484,10 @@ class Metric: """class-aware result, return p[i], r[i], ap50[i], ap[i]""" return self.p[i], self.r[i], self.ap50[i], self.ap[i] - def get_maps(self, nc): - maps = np.zeros(nc) + self.map + @property + def maps(self): + """mAP of each class""" + maps = np.zeros(self.nc) + self.map for i, c in enumerate(self.ap_class_index): maps[c] = self.ap[i] return maps @@ -500,33 +511,35 @@ class DetMetrics: self.save_dir = save_dir self.plot = plot self.names = names - self.metric = Metric() + self.box = Metric() def process(self, tp, conf, pred_cls, target_cls): results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir, names=self.names)[2:] - self.metric.update(results) + self.box.nc = len(self.names) + self.box.update(results) @property def keys(self): return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] def mean_results(self): - return self.metric.mean_results() + return self.box.mean_results() def class_result(self, i): - return self.metric.class_result(i) + return self.box.class_result(i) - def get_maps(self, nc): - return self.metric.get_maps(nc) + @property + def maps(self): + return self.box.maps @property def fitness(self): - return self.metric.fitness() + return self.box.fitness() @property def ap_class_index(self): - return self.metric.ap_class_index + return self.box.ap_class_index @property def results_dict(self): @@ -539,8 +552,8 @@ class SegmentMetrics: self.save_dir = save_dir self.plot = plot self.names = names - self.metric_box = Metric() - self.metric_mask = Metric() + self.box = Metric() + self.seg = Metric() def process(self, tp_m, tp_b, conf, pred_cls, target_cls): results_mask = ap_per_class(tp_m, @@ -551,7 +564,8 @@ class SegmentMetrics: save_dir=self.save_dir, names=self.names, prefix="Mask")[2:] - self.metric_mask.update(results_mask) + self.seg.nc = len(self.names) + self.seg.update(results_mask) results_box = ap_per_class(tp_b, conf, pred_cls, @@ -560,7 +574,8 @@ class SegmentMetrics: save_dir=self.save_dir, names=self.names, prefix="Box")[2:] - self.metric_box.update(results_box) + self.box.nc = len(self.names) + self.box.update(results_box) @property def keys(self): @@ -569,22 +584,23 @@ class SegmentMetrics: "metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"] def mean_results(self): - return self.metric_box.mean_results() + self.metric_mask.mean_results() + return self.box.mean_results() + 
self.seg.mean_results() def class_result(self, i): - return self.metric_box.class_result(i) + self.metric_mask.class_result(i) + return self.box.class_result(i) + self.seg.class_result(i) - def get_maps(self, nc): - return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc) + @property + def maps(self): + return self.box.maps + self.seg.maps @property def fitness(self): - return self.metric_mask.fitness() + self.metric_box.fitness() + return self.seg.fitness() + self.box.fitness() @property def ap_class_index(self): # boxes and masks have the same ap_class_index - return self.metric_box.ap_class_index + return self.box.ap_class_index @property def results_dict(self):
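For reviewers, a minimal sketch exercising the new `Metric` surface introduced in this diff: the `nc` attribute, the `map75` property (column 5 of the 10 IoU thresholds 0.5:0.95), and the `maps` property that replaces `get_maps(nc)`. The dummy AP values are illustrative, and the `ap` property is assumed from the unchanged part of the class:

```python
import numpy as np

from ultralytics.yolo.utils.metrics import Metric

m = Metric()
m.nc = 2  # new attribute, set by DetMetrics/SegmentMetrics.process()
m.all_ap = np.array([[0.9] * 10, [0.6] * 10])  # AP per class at IoU thresholds 0.5:0.95
m.ap_class_index = np.array([0, 1])

print(m.map50)  # mean AP at IoU 0.5 (column 0)
print(m.map75)  # mean AP at IoU 0.75 (column 5) -- new property
print(m.maps)   # per-class mAP50-95 -- property replacing get_maps(nc)
```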