ultralytics 8.0.89 SAM predict and auto-annotate (#2298)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Snyk bot <snyk-bot@snyk.io>
Co-authored-by: Laughing-q <1185102784@qq.com>
Glenn Jocher
2023-04-28 00:36:50 +02:00
committed by GitHub
parent 3e118f6170
commit 243fc4b1fe
44 changed files with 2915 additions and 440 deletions


@@ -13,20 +13,8 @@ try:
 except (ImportError, AssertionError):
     comet_ml = None
-COMET_MODE = os.getenv('COMET_MODE', 'online')
-COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'YOLOv8')
-# Determines how many batches of image predictions to log from the validation set
-COMET_EVAL_BATCH_LOGGING_INTERVAL = int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
-# Determines whether to log confusion matrix every evaluation epoch
-COMET_EVAL_LOG_CONFUSION_MATRIX = (os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'true').lower() == 'true')
-# Determines whether to log image predictions every evaluation epoch
-COMET_EVAL_LOG_IMAGE_PREDICTIONS = (os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true')
-COMET_MAX_IMAGE_PREDICTIONS = int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
 # Ensures certain logging functions only run for supported tasks
 COMET_SUPPORTED_TASKS = ['detect']
-# Scales reported confidence scores (0.0-1.0) by this value
-COMET_MAX_CONFIDENCE_SCORE = int(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100))
 # Names of plots created by YOLOv8 that are logged to Comet
 EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
@@ -35,6 +23,35 @@ LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
 _comet_image_prediction_count = 0
+def _get_comet_mode():
+    return os.getenv('COMET_MODE', 'online')
+def _get_comet_model_name():
+    return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
+def _get_eval_batch_logging_interval():
+    return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
+def _get_max_image_predictions_to_log():
+    return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
+def _scale_confidence_score(score):
+    scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
+    return score * scale
+def _should_log_confusion_matrix():
+    return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'true').lower() == 'true'
+def _should_log_image_predictions():
+    return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
 def _get_experiment_type(mode, project_name):
     """Return an experiment based on mode and project name."""
     if mode == 'offline':
@@ -48,13 +65,14 @@ def _create_experiment(args):
     if RANK not in (-1, 0):
         return
     try:
-        experiment = _get_experiment_type(COMET_MODE, args.project)
+        comet_mode = _get_comet_mode()
+        experiment = _get_experiment_type(comet_mode, args.project)
         experiment.log_parameters(vars(args))
         experiment.log_others({
-            'eval_batch_logging_interval': COMET_EVAL_BATCH_LOGGING_INTERVAL,
-            'log_confusion_matrix': COMET_EVAL_LOG_CONFUSION_MATRIX,
-            'log_image_predictions': COMET_EVAL_LOG_IMAGE_PREDICTIONS,
-            'max_image_predictions': COMET_MAX_IMAGE_PREDICTIONS, })
+            'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
+            'log_confusion_matrix': _should_log_confusion_matrix(),
+            'log_image_predictions': _should_log_image_predictions(),
+            'max_image_predictions': _get_max_image_predictions_to_log(), })
         experiment.log_other('Created from', 'yolov8')
     except Exception as e:
@@ -74,7 +92,12 @@ def _fetch_trainer_metadata(trainer):
     save_interval = curr_epoch % save_period == 0
     save_assets = save and save_period > 0 and save_interval and not final_epoch
-    return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
+    return dict(
+        curr_epoch=curr_epoch,
+        curr_step=curr_step,
+        save_assets=save_assets,
+        final_epoch=final_epoch,
+    )
 def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
@@ -117,7 +140,10 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c
     data = []
     for box, label in zip(bboxes, cls_labels):
         box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
-        data.append({'boxes': [box], 'label': f'gt_{label}', 'score': COMET_MAX_CONFIDENCE_SCORE})
+        data.append({
+            'boxes': [box],
+            'label': f'gt_{label}',
+            'score': _scale_confidence_score(1.0), })
     return {'name': 'ground_truth', 'data': data}
@@ -135,7 +161,7 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
     data = []
     for prediction in predictions:
         boxes = prediction['bbox']
-        score = prediction['score'] * COMET_MAX_CONFIDENCE_SCORE
+        score = _scale_confidence_score(prediction['score'])
         cls_label = prediction['category_id']
         if class_label_map:
             cls_label = str(class_label_map[cls_label])
@@ -207,13 +233,16 @@ def _log_image_predictions(experiment, validator, curr_step):
     dataloader = validator.dataloader
     class_label_map = validator.names
+    batch_logging_interval = _get_eval_batch_logging_interval()
+    max_image_predictions = _get_max_image_predictions_to_log()
     for batch_idx, batch in enumerate(dataloader):
-        if (batch_idx + 1) % COMET_EVAL_BATCH_LOGGING_INTERVAL != 0:
+        if (batch_idx + 1) % batch_logging_interval != 0:
             continue
         image_paths = batch['im_file']
         for img_idx, image_path in enumerate(image_paths):
-            if _comet_image_prediction_count >= COMET_MAX_IMAGE_PREDICTIONS:
+            if _comet_image_prediction_count >= max_image_predictions:
                 return
             image_path = Path(image_path)
@@ -244,8 +273,9 @@ def _log_plots(experiment, trainer):
 def _log_model(experiment, trainer):
     """Log the best-trained model to Comet.ml."""
+    model_name = _get_comet_model_name()
     experiment.log_model(
-        COMET_MODEL_NAME,
+        model_name,
         file_or_folder=str(trainer.best),
         file_name='best.pt',
         overwrite=True,
@@ -255,7 +285,8 @@ def _log_model(experiment, trainer):
 def on_pretrain_routine_start(trainer):
     """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
     experiment = comet_ml.get_global_experiment()
-    if not experiment:
+    is_alive = getattr(experiment, 'alive', False)
+    if not experiment or not is_alive:
         _create_experiment(trainer.args)
@@ -296,16 +327,16 @@ def on_fit_epoch_end(trainer):
         model_info = {
             'model/parameters': get_num_params(trainer.model),
             'model/GFLOPs': round(get_flops(trainer.model), 3),
-            'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
+            'model/speed(ms)': round(trainer.validator.speed['inference'], 3), }
         experiment.log_metrics(model_info, step=curr_step, epoch=curr_epoch)
     if not save_assets:
         return
     _log_model(experiment, trainer)
-    if COMET_EVAL_LOG_CONFUSION_MATRIX:
+    if _should_log_confusion_matrix():
         _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
-    if COMET_EVAL_LOG_IMAGE_PREDICTIONS:
+    if _should_log_image_predictions():
         _log_image_predictions(experiment, trainer.validator, curr_step)
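
Usage note: after this refactor the Comet callback reads its settings from environment variables at call time rather than at import time, so they can be set right before training. The following is a minimal sketch, not part of the diff, assuming comet_ml is installed and COMET_API_KEY is already configured; the variable names come from the getters above, the values are illustrative.

# Minimal sketch: configure the Comet callback via environment variables
# before training. Assumes `comet_ml` is installed and COMET_API_KEY is set;
# the values shown are illustrative, defaults come from the getters above.
import os

from ultralytics import YOLO

os.environ['COMET_MODE'] = 'offline'                     # read by _get_comet_mode()
os.environ['COMET_MODEL_NAME'] = 'YOLOv8'                # read by _get_comet_model_name()
os.environ['COMET_EVAL_BATCH_LOGGING_INTERVAL'] = '1'    # log predictions for every validation batch
os.environ['COMET_EVAL_LOG_CONFUSION_MATRIX'] = 'true'   # checked by _should_log_confusion_matrix()
os.environ['COMET_EVAL_LOG_IMAGE_PREDICTIONS'] = 'true'  # checked by _should_log_image_predictions()
os.environ['COMET_MAX_IMAGE_PREDICTIONS'] = '100'        # cap on logged prediction images
os.environ['COMET_MAX_CONFIDENCE_SCORE'] = '100'         # scale applied by _scale_confidence_score()

model = YOLO('yolov8n.pt')
model.train(data='coco128.yaml', epochs=3)  # Comet callbacks fire during training and validation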