`ultralytics 8.0.155` allow `imgsz` and `batch` resume changes (#4366)

Co-authored-by: Mostafa Nemati <58460889+monemati@users.noreply.github.com>
Co-authored-by: Eduard Voiculescu <eduardvoiculescu95@gmail.com>
single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 60cad0c592
commit 9a0555eca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -321,7 +321,7 @@ All supported arguments:
| `augment` | `bool` | `False` | apply image augmentation to prediction sources | | `augment` | `bool` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS | | `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks | | `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
| `classes` | `None or list` | `None` | filter results by class, i.e. class=0, or class=[0,2,3] | | `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
| `boxes` | `bool` | `True` | Show boxes in segmentation predictions | | `boxes` | `bool` | `True` | Show boxes in segmentation predictions |
## Image and Video Formats ## Image and Video Formats

@ -9,50 +9,18 @@ keywords: Ultralytics, Trackers Utils, Matching, merge_matches, linear_assignmen
Full source code for this file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/trackers/utils/matching.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/trackers/utils/matching.py). Help us fix any issues you see by submitting a [Pull Request](https://docs.ultralytics.com/help/contributing/) 🛠️. Thank you 🙏! Full source code for this file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/trackers/utils/matching.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/trackers/utils/matching.py). Help us fix any issues you see by submitting a [Pull Request](https://docs.ultralytics.com/help/contributing/) 🛠️. Thank you 🙏!
---
## ::: ultralytics.trackers.utils.matching.merge_matches
<br><br>
---
## ::: ultralytics.trackers.utils.matching._indices_to_matches
<br><br>
--- ---
## ::: ultralytics.trackers.utils.matching.linear_assignment ## ::: ultralytics.trackers.utils.matching.linear_assignment
<br><br> <br><br>
---
## ::: ultralytics.trackers.utils.matching.ious
<br><br>
--- ---
## ::: ultralytics.trackers.utils.matching.iou_distance ## ::: ultralytics.trackers.utils.matching.iou_distance
<br><br> <br><br>
---
## ::: ultralytics.trackers.utils.matching.v_iou_distance
<br><br>
--- ---
## ::: ultralytics.trackers.utils.matching.embedding_distance ## ::: ultralytics.trackers.utils.matching.embedding_distance
<br><br> <br><br>
---
## ::: ultralytics.trackers.utils.matching.gate_cost_matrix
<br><br>
---
## ::: ultralytics.trackers.utils.matching.fuse_motion
<br><br>
---
## ::: ultralytics.trackers.utils.matching.fuse_iou
<br><br>
--- ---
## ::: ultralytics.trackers.utils.matching.fuse_score ## ::: ultralytics.trackers.utils.matching.fuse_score
<br><br> <br><br>
---
## ::: ultralytics.trackers.utils.matching.bbox_ious
<br><br>

@ -154,7 +154,7 @@ The prediction settings for YOLO models encompass a range of hyperparameters and
| `augment` | `False` | apply image augmentation to prediction sources | | `augment` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `False` | class-agnostic NMS | | `agnostic_nms` | `False` | class-agnostic NMS |
| `retina_masks` | `False` | use high-resolution segmentation masks | | `retina_masks` | `False` | use high-resolution segmentation masks |
| `classes` | `None` | filter results by class, i.e. class=0, or class=[0,2,3] | | `classes` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
| `boxes` | `True` | Show boxes in segmentation predictions | | `boxes` | `True` | Show boxes in segmentation predictions |
[Predict Guide](../modes/predict.md){ .md-button .md-button--primary} [Predict Guide](../modes/predict.md){ .md-button .md-button--primary}

@ -8,7 +8,7 @@ This example demonstrates how to perform inference using YOLOv8 and YOLOv5 model
git clone ultralytics git clone ultralytics
cd ultralytics cd ultralytics
pip install . pip install .
cd examples/cpp_ cd examples/YOLOv8-CPP-Inference
# Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder. # Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder.
# Edit the **main.cpp** to change the **projectBasePath** to match your user. # Edit the **main.cpp** to change the **projectBasePath** to match your user.

@ -55,7 +55,7 @@ def test_predict_online(task, model, data):
mode = 'track' if task in ('detect', 'segment', 'pose') else 'predict' # mode for video inference mode = 'track' if task in ('detect', 'segment', 'pose') else 'predict' # mode for video inference
model = WEIGHT_DIR / model model = WEIGHT_DIR / model
run(f'yolo predict model={model}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32') run(f'yolo predict model={model}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo {mode} model={model}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32') run(f'yolo {mode} model={model}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=96')
# Run Python YouTube tracking because CLI is broken. TODO: fix CLI YouTube # Run Python YouTube tracking because CLI is broken. TODO: fix CLI YouTube
# run(f'yolo {mode} model={model}.pt source=https://youtu.be/G17sBkb38XQ imgsz=32 tracker=bytetrack.yaml') # run(f'yolo {mode} model={model}.pt source=https://youtu.be/G17sBkb38XQ imgsz=32 tracker=bytetrack.yaml')

@ -18,6 +18,7 @@ WEIGHTS_DIR = Path(SETTINGS['weights_dir'])
MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path
CFG = 'yolov8n.yaml' CFG = 'yolov8n.yaml'
SOURCE = ROOT / 'assets/bus.jpg' SOURCE = ROOT / 'assets/bus.jpg'
TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files
SOURCE_GREYSCALE = Path(f'{SOURCE.parent / SOURCE.stem}_greyscale.jpg') SOURCE_GREYSCALE = Path(f'{SOURCE.parent / SOURCE.stem}_greyscale.jpg')
SOURCE_RGBA = Path(f'{SOURCE.parent / SOURCE.stem}_4ch.png') SOURCE_RGBA = Path(f'{SOURCE.parent / SOURCE.stem}_4ch.png')
@ -92,7 +93,7 @@ def test_predict_grey_and_4ch():
def test_track_stream(): def test_track_stream():
# Test YouTube streaming inference (short 10 frame video) with non-default ByteTrack tracker # Test YouTube streaming inference (short 10 frame video) with non-default ByteTrack tracker
model = YOLO(MODEL) model = YOLO(MODEL)
model.track('https://youtu.be/G17sBkb38XQ', imgsz=32, tracker='bytetrack.yaml') model.track('https://youtu.be/G17sBkb38XQ', imgsz=96, tracker='bytetrack.yaml')
def test_val(): def test_val():
@ -232,16 +233,15 @@ def test_data_utils():
# from ultralytics.utils.files import WorkingDirectory # from ultralytics.utils.files import WorkingDirectory
# with WorkingDirectory(ROOT.parent / 'tests'): # with WorkingDirectory(ROOT.parent / 'tests'):
Path('tests/coco8.zip').unlink(missing_ok=True) shutil.rmtree(TMP, ignore_errors=True)
Path('coco8.zip').unlink(missing_ok=True) TMP.mkdir(parents=True)
download('https://github.com/ultralytics/hub/raw/master/example_datasets/coco8.zip', unzip=False) download('https://github.com/ultralytics/hub/raw/master/example_datasets/coco8.zip', unzip=False)
shutil.move('coco8.zip', 'tests') shutil.move('coco8.zip', TMP)
shutil.rmtree('tests/coco8', ignore_errors=True) stats = HUBDatasetStats(TMP / 'coco8.zip', task='detect')
stats = HUBDatasetStats('tests/coco8.zip', task='detect')
stats.get_json(save=False) stats.get_json(save=False)
stats.process_images() stats.process_images()
autosplit('tests/coco8') autosplit(TMP / 'coco8')
zip_directory('tests/coco8/images/val') # zip zip_directory(TMP / 'coco8/images/val') # zip
shutil.rmtree('tests/coco8', ignore_errors=True) shutil.rmtree(TMP)
shutil.rmtree('tests/coco8-hub', ignore_errors=True)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.154' __version__ = '8.0.155'
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models import RTDETR, SAM, YOLO

@ -64,7 +64,7 @@ line_width: # (int, optional) line width of the bounding boxes, auto if missin
visualize: False # (bool) visualize model features visualize: False # (bool) visualize model features
augment: False # (bool) apply image augmentation to prediction sources augment: False # (bool) apply image augmentation to prediction sources
agnostic_nms: False # (bool) class-agnostic NMS agnostic_nms: False # (bool) class-agnostic NMS
classes: # (int | list[int], optional) filter results by class, i.e. class=0, or class=[0,2,3] classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
retina_masks: False # (bool) use high-resolution segmentation masks retina_masks: False # (bool) use high-resolution segmentation masks
boxes: True # (bool) Show boxes in segmentation predictions boxes: True # (bool) Show boxes in segmentation predictions

@ -120,7 +120,7 @@ def check_source(source):
screenshot = source.lower() == 'screen' screenshot = source.lower() == 'screen'
if is_url and is_file: if is_url and is_file:
source = check_file(source) # download source = check_file(source) # download
elif isinstance(source, tuple(LOADERS)): elif isinstance(source, LOADERS):
in_memory = True in_memory = True
elif isinstance(source, (list, tuple)): elif isinstance(source, (list, tuple)):
source = autocast_list(source) # convert all list elements to PIL or np arrays source = autocast_list(source) # convert all list elements to PIL or np arrays

@ -98,7 +98,7 @@ class LoadStreams:
def close(self): def close(self):
"""Close stream loader and release resources.""" """Close stream loader and release resources."""
self.running = False # stop flag for Thread self.running = False # stop flag for Thread
for i, thread in enumerate(self.threads): for thread in self.threads:
if thread.is_alive(): if thread.is_alive():
thread.join(timeout=5) # Add timeout thread.join(timeout=5) # Add timeout
for cap in self.caps: # Iterate through the stored VideoCapture objects for cap in self.caps: # Iterate through the stored VideoCapture objects
@ -210,7 +210,6 @@ class LoadImages:
self.vid_stride = vid_stride # video frame-rate stride self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1 self.bs = 1
if any(videos): if any(videos):
self.orientation = None # rotation degrees
self._new_video(videos[0]) # new video self._new_video(videos[0]) # new video
else: else:
self.cap = None self.cap = None
@ -263,20 +262,6 @@ class LoadImages:
self.frame = 0 self.frame = 0
self.cap = cv2.VideoCapture(path) self.cap = cv2.VideoCapture(path)
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
def _cv2_rotate(self, im):
"""Rotate a cv2 video manually."""
if self.orientation == 0:
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
elif self.orientation == 180:
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
elif self.orientation == 90:
return cv2.rotate(im, cv2.ROTATE_180)
return im
def __len__(self): def __len__(self):
"""Returns the number of files in the object.""" """Returns the number of files in the object."""
@ -385,10 +370,10 @@ def autocast_list(source):
return files return files
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots] LOADERS = LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots # tuple
def get_best_youtube_url(url, use_pafy=True): def get_best_youtube_url(url, use_pafy=False):
""" """
Retrieves the URL of the best quality MP4 video stream from a given YouTube video. Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
@ -411,9 +396,11 @@ def get_best_youtube_url(url, use_pafy=True):
import yt_dlp import yt_dlp
with yt_dlp.YoutubeDL({'quiet': True}) as ydl: with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
info_dict = ydl.extract_info(url, download=False) # extract info info_dict = ydl.extract_info(url, download=False) # extract info
for f in info_dict.get('formats', None): for f in reversed(info_dict.get('formats', [])): # reversed because best is usually last
if f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4' and f.get('width') > 1280: # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
return f.get('url', None) good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080
if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
return f.get('url')
if __name__ == '__main__': if __name__ == '__main__':

@ -142,16 +142,12 @@ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
downsample_ratio (int): downsample ratio downsample_ratio (int): downsample ratio
""" """
mask = np.zeros(imgsz, dtype=np.uint8) mask = np.zeros(imgsz, dtype=np.uint8)
polygons = np.asarray(polygons) polygons = np.asarray(polygons, dtype=np.int32)
polygons = polygons.astype(np.int32) polygons = polygons.reshape((polygons.shape[0], -1, 2))
shape = polygons.shape
polygons = polygons.reshape(shape[0], -1, 2)
cv2.fillPoly(mask, polygons, color=color) cv2.fillPoly(mask, polygons, color=color)
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# NOTE: fillPoly firstly then resize is trying the keep the same way # NOTE: fillPoly first then resize is trying to keep the same way of loss calculation when mask-ratio=1.
# of loss calculation when mask-ratio=1. return cv2.resize(mask, (nw, nh))
mask = cv2.resize(mask, (nw, nh))
return mask
def polygons2masks(imgsz, polygons, color, downsample_ratio=1): def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
@ -162,11 +158,7 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
color (int): color color (int): color
downsample_ratio (int): downsample ratio downsample_ratio (int): downsample ratio
""" """
masks = [] return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
for si in range(len(polygons)):
mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
masks.append(mask)
return np.array(masks)
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
@ -421,7 +413,7 @@ class HUBDatasetStats:
else: else:
raise ValueError('Undefined dataset task.') raise ValueError('Undefined dataset task.')
zipped = zip(labels['cls'], coordinates) zipped = zip(labels['cls'], coordinates)
return [[int(c), *(round(float(x), 4) for x in points)] for c, points in zipped] return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
for split in 'train', 'val', 'test': for split in 'train', 'val', 'test':
if self.data.get(split) is None: if self.data.get(split) is None:
@ -563,7 +555,7 @@ def zip_directory(dir, use_zipfile_library=True):
def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False): def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
""" """
Autosplit a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files. Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
Args: Args:
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'. path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.

@ -249,11 +249,11 @@ class Exporter:
f[4], _ = self.export_coreml() f[4], _ = self.export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
self.args.int8 |= edgetpu self.args.int8 |= edgetpu
f[5], s_model = self.export_saved_model() f[5], keras_model = self.export_saved_model()
if pb or tfjs: # pb prerequisite to tfjs if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self.export_pb(s_model) f[6], _ = self.export_pb(keras_model=keras_model)
if tflite: if tflite:
f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms) f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
if edgetpu: if edgetpu:
f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite') f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
if tfjs: if tfjs:
@ -671,10 +671,7 @@ class Exporter:
for file in f.rglob('*.tflite'): for file in f.rglob('*.tflite'):
f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file) f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
# Load saved_model return str(f), tf.saved_model.load(f, tags=None, options=None) # load saved_model as Keras model
keras_model = tf.saved_model.load(f, tags=None, options=None)
return str(f), keras_model
@try_export @try_export
def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')): def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):

@ -81,7 +81,7 @@ class BaseTrainer:
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)
self.check_resume() self.check_resume(overrides)
self.device = select_device(self.args.device, self.args.batch) self.device = select_device(self.args.device, self.args.batch)
self.validator = None self.validator = None
self.model = None self.model = None
@ -576,7 +576,7 @@ class BaseTrainer:
self.metrics.pop('fitness', None) self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end') self.run_callbacks('on_fit_epoch_end')
def check_resume(self): def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly.""" """Check if resume checkpoint exists and update arguments accordingly."""
resume = self.args.resume resume = self.args.resume
if resume: if resume:
@ -589,8 +589,13 @@ class BaseTrainer:
if not Path(ckpt_args['data']).exists(): if not Path(ckpt_args['data']).exists():
ckpt_args['data'] = self.args.data ckpt_args['data'] = self.args.data
resume = True
self.args = get_cfg(ckpt_args) self.args = get_cfg(ckpt_args)
self.args.model, resume = str(last), True # reinstate self.args.model = str(last) # reinstate model
for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
if k in overrides:
setattr(self.args, k, overrides[k])
except Exception as e: except Exception as e:
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, ' raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
"i.e. 'yolo train resume model=path/to/last.pt'") from e "i.e. 'yolo train resume model=path/to/last.pt'") from e

@ -0,0 +1 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

@ -18,7 +18,18 @@ except (ImportError, AssertionError, AttributeError):
def linear_assignment(cost_matrix, thresh, use_lap=True): def linear_assignment(cost_matrix, thresh, use_lap=True):
"""Linear assignment implementations with scipy and lap.lapjv.""" """
Perform linear assignment using scipy or lap.lapjv.
Args:
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
thresh (float): Threshold for considering an assignment valid.
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
Returns:
(tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'.
"""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
@ -42,11 +53,14 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
def iou_distance(atracks, btracks): def iou_distance(atracks, btracks):
""" """
Compute cost based on IoU Compute cost based on Intersection over Union (IoU) between tracks.
:type atracks: list[STrack]
:type btracks: list[STrack] Args:
atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.
:rtype cost_matrix np.ndarray Returns:
(np.ndarray): Cost matrix computed based on IoU.
""" """
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \ if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
@ -67,10 +81,15 @@ def iou_distance(atracks, btracks):
def embedding_distance(tracks, detections, metric='cosine'): def embedding_distance(tracks, detections, metric='cosine'):
""" """
:param tracks: list[STrack] Compute distance between tracks and detections based on embeddings.
:param detections: list[BaseTrack]
:param metric: Args:
:return: cost_matrix np.ndarray tracks (list[STrack]): List of tracks.
detections (list[BaseTrack]): List of detections.
metric (str, optional): Metric for distance computation. Defaults to 'cosine'.
Returns:
(np.ndarray): Cost matrix computed based on embeddings.
""" """
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
@ -85,7 +104,17 @@ def embedding_distance(tracks, detections, metric='cosine'):
def fuse_score(cost_matrix, detections): def fuse_score(cost_matrix, detections):
"""Fuses cost matrix with detection scores to produce a single similarity matrix.""" """
Fuses cost matrix with detection scores to produce a single similarity matrix.
Args:
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
detections (list[BaseTrack]): List of detections with scores.
Returns:
(np.ndarray): Fused similarity matrix.
"""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return cost_matrix return cost_matrix
iou_sim = 1 - cost_matrix iou_sim = 1 - cost_matrix

Loading…
Cancel
Save