`8.0.60` new HUB training syntax (#1753)

Co-authored-by: Rafael Pierre <97888102+rafaelvp-db@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Semih Demirel <85176438+semihhdemirel@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent e7876e1ba9
commit 84948651cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ on:
push: push:
branches: [main] branches: [main]
pull_request: pull_request:
branches: [main] branches: [main, updates]
schedule: schedule:
- cron: '0 0 * * *' # runs at 00:00 UTC every day - cron: '0 0 * * *' # runs at 00:00 UTC every day
@ -43,16 +43,36 @@ jobs:
python --version python --version
pip --version pip --version
pip list pip list
- name: Test HUB training - name: Test HUB training (Python Usage 1)
shell: python
env:
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
run: |
import os
from pathlib import Path
from ultralytics import YOLO, hub
from ultralytics.yolo.utils import USER_CONFIG_DIR
Path(USER_CONFIG_DIR / 'settings.yaml').unlink()
key = os.environ['APIKEY']
hub.reset_model(key)
model = YOLO('https://hub.ultralytics.com/models/' + key)
model.train()
- name: Test HUB training (Python Usage 2)
shell: python shell: python
env: env:
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
run: | run: |
import os import os
from ultralytics import hub from pathlib import Path
from ultralytics import YOLO, hub
from ultralytics.yolo.utils import USER_CONFIG_DIR
Path(USER_CONFIG_DIR / 'settings.yaml').unlink()
key = os.environ['APIKEY'] key = os.environ['APIKEY']
hub.reset_model(key) hub.reset_model(key)
hub.start(key) key, model_id = key.split('_')
hub.login(key)
model = YOLO('https://hub.ultralytics.com/models/' + model_id)
model.train()
Benchmarks: Benchmarks:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

@ -26,6 +26,7 @@ WORKDIR /usr/src/ultralytics
# Copy contents # Copy contents
# COPY . /usr/src/app (issues as not a .git directory) # COPY . /usr/src/app (issues as not a .git directory)
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/
# Install pip packages # Install pip packages
RUN python3 -m pip install --upgrade pip wheel RUN python3 -m pip install --upgrade pip wheel

@ -22,6 +22,7 @@ WORKDIR /usr/src/ultralytics
# Copy contents # Copy contents
# COPY . /usr/src/app (issues as not a .git directory) # COPY . /usr/src/app (issues as not a .git directory)
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/
# Install pip packages # Install pip packages
RUN python3 -m pip install --upgrade pip wheel RUN python3 -m pip install --upgrade pip wheel

@ -22,6 +22,7 @@ WORKDIR /usr/src/ultralytics
# Copy contents # Copy contents
# COPY . /usr/src/app (issues as not a .git directory) # COPY . /usr/src/app (issues as not a .git directory)
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/
# Install pip packages # Install pip packages
RUN python3 -m pip install --upgrade pip wheel RUN python3 -m pip install --upgrade pip wheel

@ -17,7 +17,7 @@ passing `stream=True` in the predictor's call method.
probs = result.probs # Class probabilities for classification outputs probs = result.probs # Class probabilities for classification outputs
``` ```
=== "Return a list with `Stream=True`" === "Return a generator with `Stream=True`"
```python ```python
inputs = [img, img] # list of numpy arrays inputs = [img, img] # list of numpy arrays
results = model(inputs, stream=True) # generator of Results objects results = model(inputs, stream=True) # generator of Results objects
@ -54,6 +54,40 @@ whether each source can be used in streaming mode with `stream=True` ✅ and an
| YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | | | YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | |
| stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | RTSP, RTMP, HTTP | | stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | RTSP, RTMP, HTTP |
## Arguments
`model.predict` accepts multiple arguments that control the predction operation. These arguments can be passed directly to `model.predict`:
!!! example
```
model.predict(source, save=True, imgsz=320, conf=0.5)
```
All supported arguments:
| Key | Value | Description |
|------------------|------------------------|----------------------------------------------------------|
| `source` | `'ultralytics/assets'` | source directory for images or videos |
| `conf` | `0.25` | object confidence threshold for detection |
| `iou` | `0.7` | intersection over union (IoU) threshold for NMS |
| `half` | `False` | use half precision (FP16) |
| `device` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
| `show` | `False` | show results if possible |
| `save` | `False` | save images with results |
| `save_txt` | `False` | save results as .txt file |
| `save_conf` | `False` | save results with confidence scores |
| `save_crop` | `False` | save cropped images with results |
| `hide_labels` | `False` | hide labels |
| `hide_conf` | `False` | hide confidence scores |
| `max_det` | `300` | maximum number of detections per image |
| `vid_stride` | `False` | video frame-rate stride |
| `line_thickness` | `3` | bounding box thickness (pixels) |
| `visualize` | `False` | visualize model features |
| `augment` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `False` | class-agnostic NMS |
| `retina_masks` | `False` | use high-resolution segmentation masks |
| `classes` | `None` | filter results by class, i.e. class=0, or class=[0,2,3] |
| `boxes` | `True` | Show boxes in segmentation predictions |
## Image and Video Formats ## Image and Video Formats
YOLOv8 supports various image and video formats, as specified YOLOv8 supports various image and video formats, as specified

@ -96,7 +96,6 @@ names:
77: teddy bear 77: teddy bear
78: hair drier 78: hair drier
79: toothbrush 79: toothbrush
``` ```

@ -6,7 +6,7 @@ Just simply clone and run
```bash ```bash
pip install -r requirements.txt pip install -r requirements.txt
python main.py python main.py --model yolov8n.onnx --img image.jpg
``` ```
If you start from scratch: If you start from scratch:

@ -1,3 +1,5 @@
import argparse
import cv2.dnn import cv2.dnn
import numpy as np import numpy as np
@ -16,9 +18,9 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
def main(): def main(onnx_model, input_image):
model: cv2.dnn.Net = cv2.dnn.readNetFromONNX('yolov8n.onnx') model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model)
original_image: np.ndarray = cv2.imread(str(ROOT / 'assets/bus.jpg')) original_image: np.ndarray = cv2.imread(input_image)
[height, width, _] = original_image.shape [height, width, _] = original_image.shape
length = max((height, width)) length = max((height, width))
image = np.zeros((length, length, 3), np.uint8) image = np.zeros((length, length, 3), np.uint8)
@ -71,4 +73,8 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='yolov8n.onnx', help='Input your onnx model.')
parser.add_argument('--img', default=str(ROOT / 'assets/bus.jpg'), help='Path to input image.')
args = parser.parse_args()
main(args.model, args.img)

@ -46,7 +46,7 @@ theme:
- content.tabs.link # all code tabs change simultaneously - content.tabs.link # all code tabs change simultaneously
# Customization # Customization
copyright: Ultralytics 2023. All rights reserved. copyright: <a href="https://ultralytics.com" target="_blank">Ultralytics 2023.</a> All rights reserved.
extra: extra:
# version: # version:
# provider: mike # version drop-down menu # provider: mike # version drop-down menu
@ -167,7 +167,7 @@ nav:
- Hyperparameter evolution: yolov5/hyp_evolution.md - Hyperparameter evolution: yolov5/hyp_evolution.md
- Transfer learning with frozen layers: yolov5/transfer_learn_frozen.md - Transfer learning with frozen layers: yolov5/transfer_learn_frozen.md
- Architecture Summary: yolov5/architecture.md - Architecture Summary: yolov5/architecture.md
- Roboflow for Datasets, Labeling, and Active Learning: yolov5/roboflow.md - Roboflow Datasets: yolov5/roboflow.md
- Neural Magic's DeepSparse: yolov5/neural_magic.md - Neural Magic's DeepSparse: yolov5/neural_magic.md
- Comet Logging: yolov5/comet.md - Comet Logging: yolov5/comet.md
- Clearml Logging: yolov5/clearml.md - Clearml Logging: yolov5/clearml.md

@ -58,7 +58,7 @@ setup(
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Image Recognition', 'Topic :: Scientific/Engineering :: Image Recognition',
'Operating System :: POSIX :: Linux', 'Operating System :: POSIX :: Linux',
'Operating System :: macOS', 'Operating System :: MacOS',
'Operating System :: Microsoft :: Windows', ], 'Operating System :: Microsoft :: Windows', ],
keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics', keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics',
entry_points={ entry_points={

@ -56,11 +56,11 @@ def test_predict_detect():
def test_predict_segment(): def test_predict_segment():
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save") run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
def test_predict_classify(): def test_predict_classify():
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save") run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
# Export checks -------------------------------------------------------------------------------------------------------- # Export checks --------------------------------------------------------------------------------------------------------

@ -1,8 +1,9 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.59' __version__ = '8.0.60'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks from ultralytics.yolo.utils.checks import check_yolo as checks
__all__ = '__version__', 'YOLO', 'checks' # allow simpler import __all__ = '__version__', 'YOLO', 'checks', 'start' # allow simpler import

@ -2,47 +2,51 @@
import requests import requests
from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HUBTrainingSession
from ultralytics.hub.utils import PREFIX, split_key from ultralytics.hub.utils import PREFIX, split_key
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils import LOGGER, emojis
def start(key=''): def login(api_key=''):
"""
Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
""" """
auth = Auth(key) Log in to the Ultralytics HUB API using the provided API key.
model_id = split_key(key)[1] if auth.get_state() else request_api_key(auth)
if not model_id:
raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
session = HUBTrainingSession(model_id=model_id, auth=auth) Args:
session.check_disk_space() api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
model = YOLO(model=session.model_file, session=session) Example:
model.train(**session.train_args) from ultralytics import hub
hub.login('your_api_key')
"""
from ultralytics.hub.auth import Auth
Auth(api_key)
def request_api_key(auth, max_attempts=3): def logout():
""" """
Prompt the user to input their API key. Returns the model ID. Logout Ultralytics HUB
Example:
from ultralytics import hub
hub.logout()
""" """
import getpass LOGGER.warning('WARNING ⚠️ This method is not yet implemented.')
for attempts in range(max_attempts):
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
input_key = getpass.getpass(
'Enter your Ultralytics API Key from https://hub.ultralytics.com/settings?tab=api+keys:\n')
auth.api_key, model_id = split_key(input_key)
if auth.authenticate(): def start(key=''):
LOGGER.info(f'{PREFIX}Authenticated ✅') """
return model_id Start training models with Ultralytics HUB (DEPRECATED).
Args:
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
"""
LOGGER.warning(f"""
WARNING ultralytics.start() is deprecated in 8.0.60. Updated usage to train your Ultralytics HUB model is below:
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n') from ultralytics import YOLO
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌')) model = YOLO('https://hub.ultralytics.com/models/{key}')
model.train()""")
def reset_model(key=''): def reset_model(key=''):

@ -2,27 +2,74 @@
import requests import requests
from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
from ultralytics.yolo.utils import is_colab from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys' API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
class Auth: class Auth:
id_token = api_key = model_key = False id_token = api_key = model_key = False
def __init__(self, api_key=None): def __init__(self, api_key=''):
self.api_key = self._clean_api_key(api_key) """
self.authenticate() if self.api_key else self.auth_with_cookies() Initialize the Auth class with an optional API key.
Args:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
"""
# Split the input API key in case it contains a combined key_model and keep only the API key part
api_key = api_key.split('_')[0]
# Set API key attribute as value passed or SETTINGS API key if none passed
self.api_key = api_key or SETTINGS.get('api_key', '')
# If an API key is provided
if self.api_key:
# If the provided API key matches the API key in the SETTINGS
if self.api_key == SETTINGS.get('api_key'):
# Log that the user is already logged in
LOGGER.info(f'{PREFIX}Authenticated ✅')
return
else:
# Attempt to authenticate with the provided API key
success = self.authenticate()
# If the API key is not provided and the environment is a Google Colab notebook
elif is_colab():
# Attempt to authenticate using browser cookies
success = self.auth_with_cookies()
else:
# Request an API key
success = self.request_api_key()
@staticmethod # Update SETTINGS with the new API key after successful authentication
def _clean_api_key(key: str) -> str: if success:
"""Strip model from key if present""" set_settings({'api_key': self.api_key})
separator = '_' # Log that the new login was successful
return key.split(separator)[0] if separator in key else key LOGGER.info(f'{PREFIX}New authentication successful ✅')
else:
LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
def request_api_key(self, max_attempts=3):
"""
Prompt the user to input their API key. Returns the model ID.
"""
import getpass
for attempts in range(max_attempts):
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
self.api_key = input_key.split('_')[0] # remove model id if present
if self.authenticate():
return True
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
def authenticate(self) -> bool: def authenticate(self) -> bool:
"""Attempt to authenticate with server""" """
Attempt to authenticate with the server using either id_token or API key.
Returns:
bool: True if authentication is successful, False otherwise.
"""
try: try:
header = self.get_auth_header() header = self.get_auth_header()
if header: if header:
@ -33,12 +80,16 @@ class Auth:
raise ConnectionError('User has not authenticated locally.') raise ConnectionError('User has not authenticated locally.')
except ConnectionError: except ConnectionError:
self.id_token = self.api_key = False # reset invalid self.id_token = self.api_key = False # reset invalid
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
return False return False
def auth_with_cookies(self) -> bool: def auth_with_cookies(self) -> bool:
""" """
Attempt to fetch authentication via cookies and set id_token. Attempt to fetch authentication via cookies and set id_token.
User must be logged in to HUB and running in a supported browser. User must be logged in to HUB and running in a supported browser.
Returns:
bool: True if authentication is successful, False otherwise.
""" """
if not is_colab(): if not is_colab():
return False # Currently only works with Colab return False # Currently only works with Colab
@ -54,6 +105,12 @@ class Auth:
return False return False
def get_auth_header(self): def get_auth_header(self):
"""
Get the authentication header for making API requests.
Returns:
dict: The authentication header if id_token or API key is set, None otherwise.
"""
if self.id_token: if self.id_token:
return {'authorization': f'Bearer {self.id_token}'} return {'authorization': f'Bearer {self.id_token}'}
elif self.api_key: elif self.api_key:
@ -62,9 +119,19 @@ class Auth:
return None return None
def get_state(self) -> bool: def get_state(self) -> bool:
"""Get the authentication state""" """
Get the authentication state.
Returns:
bool: True if either id_token or API key is set, False otherwise.
"""
return self.id_token or self.api_key return self.id_token or self.api_key
def set_api_key(self, key: str): def set_api_key(self, key: str):
"""Get the authentication state""" """
Set the API key for authentication.
Args:
key (str): The API key string.
"""
self.api_key = key self.api_key = key

@ -6,17 +6,62 @@ from time import sleep
import requests import requests
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
class HUBTrainingSession: class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
Args:
url (str): Model identifier used to initialize the HUB training session.
Attributes:
agent_id (str): Identifier for the instance communicating with the server.
model_id (str): Identifier for the YOLOv5 model being trained.
model_url (str): URL for the model in Ultralytics HUB.
api_url (str): API URL for the model in Ultralytics HUB.
auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
rate_limits (Dict): Rate limits for different API calls (in seconds).
timers (Dict): Timers for rate limiting.
metrics_queue (Dict): Queue for the model's metrics.
model (Dict): Model data fetched from Ultralytics HUB.
alive (bool): Indicates if the heartbeat loop is active.
"""
def __init__(self, url):
"""
Initialize the HUBTrainingSession with the provided model identifier.
Args:
url (str): Model identifier used to initialize the HUB training session.
It can be a URL string or a model key with specific format.
Raises:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
"""
from ultralytics.hub.auth import Auth
def __init__(self, model_id, auth): # Parse input
if url.startswith('https://hub.ultralytics.com/models/'):
url = url.split('https://hub.ultralytics.com/models/')[-1]
if [len(x) for x in url.split('_')] == [42, 20]:
key, model_id = url.split('_')
elif len(url) == 20:
key, model_id = '', url
else:
raise ValueError(f'Invalid HUBTrainingSession input: {url}')
# Authorize
auth = Auth(key)
self.agent_id = None # identifies which instance is communicating with server self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id self.model_id = model_id
self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header() self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds) self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
@ -26,16 +71,17 @@ class HUBTrainingSession:
self.alive = True self.alive = True
self._start_heartbeat() # start heartbeats self._start_heartbeat() # start heartbeats
self._register_signal_handlers() self._register_signal_handlers()
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def _register_signal_handlers(self): def _register_signal_handlers(self):
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
signal.signal(signal.SIGTERM, self._handle_signal) signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal) signal.signal(signal.SIGINT, self._handle_signal)
def _handle_signal(self, signum, frame): def _handle_signal(self, signum, frame):
""" """
Prevent heartbeats from being sent on Colab after kill. Handle kill signals and prevent heartbeats from being sent on Colab after termination.
This method does not use frame, it is included as it is This method does not use frame, it is included as it is passed by signal.
passed by signal.
""" """
if self.alive is True: if self.alive is True:
LOGGER.info(f'{PREFIX}Kill signal received! ❌') LOGGER.info(f'{PREFIX}Kill signal received! ❌')
@ -43,15 +89,16 @@ class HUBTrainingSession:
sys.exit(signum) sys.exit(signum)
def _stop_heartbeat(self): def _stop_heartbeat(self):
"""End the heartbeat loop""" """Terminate the heartbeat loop."""
self.alive = False self.alive = False
def upload_metrics(self): def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'} payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2) smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
def _get_model(self): def _get_model(self):
# Returns model from database by id """Fetch and return model data from Ultralytics HUB."""
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
try: try:
@ -59,9 +106,7 @@ class HUBTrainingSession:
data = response.json().get('data', None) data = response.json().get('data', None)
if data.get('status', None) == 'trained': if data.get('status', None) == 'trained':
raise ValueError( raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
emojis(f'Model is already trained and uploaded to '
f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
if not data.get('data', None): if not data.get('data', None):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
@ -88,11 +133,21 @@ class HUBTrainingSession:
raise raise
def check_disk_space(self): def check_disk_space(self):
if not check_dataset_disk_space(self.model['data']): """Check if there is enough disk space for the dataset."""
if not check_dataset_disk_space(url=self.model['data']):
raise MemoryError('Not enough disk space') raise MemoryError('Not enough disk space')
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB """
Upload a model checkpoint to Ultralytics HUB.
Args:
epoch (int): The current training epoch.
weights (str): Path to the model weights file.
is_best (bool): Indicates if the current model is the best one so far.
map (float): Mean average precision of the model.
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file(): if Path(weights).is_file():
with open(weights, 'rb') as f: with open(weights, 'rb') as f:
file = f.read() file = f.read()
@ -120,6 +175,7 @@ class HUBTrainingSession:
@threaded @threaded
def _start_heartbeat(self): def _start_heartbeat(self):
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
while self.alive: while self.alive:
r = smart_request('post', r = smart_request('post',
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',

@ -22,7 +22,16 @@ HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.co
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0): def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0 """
Check if there is sufficient disk space to download and store a dataset.
Args:
url (str, optional): The URL to the dataset file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0.
Returns:
bool: True if there is sufficient disk space, False otherwise.
"""
gib = 1 << 30 # bytes per GiB gib = 1 << 30 # bytes per GiB
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB) data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
@ -35,7 +44,18 @@ def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', s
def request_with_credentials(url: str) -> any: def request_with_credentials(url: str) -> any:
""" Make an ajax request with cookies attached """ """
Make an AJAX request with cookies attached in a Google Colab environment.
Args:
url (str): The URL to make the request to.
Returns:
any: The response data from the AJAX request.
Raises:
OSError: If the function is not run in a Google Colab environment.
"""
if not is_colab(): if not is_colab():
raise OSError('request_with_credentials() must run in a Colab environment') raise OSError('request_with_credentials() must run in a Colab environment')
from google.colab import output # noqa from google.colab import output # noqa
@ -95,7 +115,6 @@ def requests_with_progress(method, url, **kwargs):
Returns: Returns:
requests.Response: The response from the HTTP request. requests.Response: The response from the HTTP request.
""" """
progress = kwargs.pop('progress', False) progress = kwargs.pop('progress', False)
if not progress: if not progress:
@ -126,7 +145,6 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
Returns: Returns:
requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None. requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
""" """
retry_codes = (408, 500) # retry only these codes retry_codes = (408, 500) # retry only these codes
@ -171,8 +189,8 @@ class Traces:
def __init__(self): def __init__(self):
""" """
Initialize Traces for error tracking and reporting if tests are not currently running. Initialize Traces for error tracking and reporting if tests are not currently running.
Sets the rate limit, timer, and metadata attributes, and determines whether Traces are enabled.
""" """
from ultralytics.yolo.cfg import MODES, TASKS
self.rate_limit = 60.0 # rate limit (seconds) self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds) self.t = 0.0 # rate limit timer (seconds)
self.metadata = { self.metadata = {
@ -187,17 +205,22 @@ class Traces:
not TESTS_RUNNING and \ not TESTS_RUNNING and \
ONLINE and \ ONLINE and \
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}} self._reset_usage()
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0): def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
""" """
Sync traces data if enabled in the global settings Sync traces data if enabled in the global settings.
Args: Args:
cfg (IterableSimpleNamespace): Configuration for the task and mode. cfg (IterableSimpleNamespace): Configuration for the task and mode.
all_keys (bool): Sync all items, not just non-default values. all_keys (bool): Sync all items, not just non-default values.
traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0 traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0.
""" """
# Increment usage
self.usage['modes'][cfg.mode] = self.usage['modes'].get(cfg.mode, 0) + 1
self.usage['tasks'][cfg.task] = self.usage['tasks'].get(cfg.task, 0) + 1
t = time.time() # current time t = time.time() # current time
if not self.enabled or random() > traces_sample_rate: if not self.enabled or random() > traces_sample_rate:
# Traces disabled or not randomly selected, do nothing # Traces disabled or not randomly selected, do nothing
@ -207,18 +230,20 @@ class Traces:
return return
else: else:
# Time is over rate limiter, send trace now # Time is over rate limiter, send trace now
self.t = t # reset rate limit timer trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage.copy(), 'metadata': self.metadata}
# Build trace
if cfg.task in self.usage['tasks']:
self.usage['tasks'][cfg.task] += 1
if cfg.mode in self.usage['modes']:
self.usage['modes'][cfg.mode] += 1
trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage, 'metadata': self.metadata}
# Send a request to the HUB API to sync analytics # Send a request to the HUB API to sync analytics
smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False) smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)
# Reset usage and rate limit timer
self._reset_usage()
self.t = t
def _reset_usage(self):
"""Reset the usage dictionary by initializing keys for each task and mode with a value of 0."""
from ultralytics.yolo.cfg import MODES, TASKS
self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}}
# Run below code on hub/utils init ------------------------------------------------------------------------------------- # Run below code on hub/utils init -------------------------------------------------------------------------------------
traces = Traces() traces = Traces()

@ -9,7 +9,8 @@ from types import SimpleNamespace
from typing import Dict, List, Union from typing import Dict, List, Union
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR, from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print) IterableSimpleNamespace, __version__, checks, colorstr, get_settings, yaml_load,
yaml_print)
# Define valid tasks and modes # Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
@ -187,6 +188,51 @@ def merge_equals_args(args: List[str]) -> List[str]:
return new_args return new_args
def handle_yolo_hub(args: List[str]) -> None:
"""
Handle Ultralytics HUB command-line interface (CLI) commands.
This function processes Ultralytics HUB CLI commands such as login and logout.
It should be called when executing a script with arguments related to HUB authentication.
Args:
args (List[str]): A list of command line arguments
Example:
python my_script.py hub login your_api_key
"""
from ultralytics import hub
if args[0] == 'login':
key = args[1] if len(args) > 1 else ''
# Log in to Ultralytics HUB using the provided API key
hub.login(key)
elif args[0] == 'logout':
# Log out from Ultralytics HUB
hub.logout()
def handle_yolo_settings(args: List[str]) -> None:
"""
Handle YOLO settings command-line interface (CLI) commands.
This function processes YOLO settings CLI commands such as reset.
It should be called when executing a script with arguments related to YOLO settings management.
Args:
args (List[str]): A list of command line arguments for YOLO settings management.
Example:
python my_script.py yolo settings reset
"""
path = USER_CONFIG_DIR / 'settings.yaml' # get SETTINGS YAML file path
if any(args) and args[0] == 'reset':
path.unlink() # delete the settings file
get_settings() # create new settings
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
yaml_print(path) # print the current settings
def entrypoint(debug=''): def entrypoint(debug=''):
""" """
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
@ -211,8 +257,10 @@ def entrypoint(debug=''):
'help': lambda: LOGGER.info(CLI_HELP_MSG), 'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo, 'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__), 'version': lambda: LOGGER.info(__version__),
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'), 'settings': lambda: handle_yolo_settings(args[1:]),
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
'hub': lambda: handle_yolo_hub(args[1:]),
'login': lambda: handle_yolo_hub(args),
'copy-cfg': copy_default_cfg} 'copy-cfg': copy_default_cfg}
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
@ -255,8 +303,8 @@ def entrypoint(debug=''):
overrides['task'] = a overrides['task'] = a
elif a in MODES: elif a in MODES:
overrides['mode'] = a overrides['mode'] = a
elif a in special: elif a.lower() in special:
special[a]() special[a.lower()]()
return return
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True

@ -68,12 +68,14 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results. list(ultralytics.yolo.engine.results.Results): The prediction results.
""" """
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None, session=None) -> None: def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
""" """
Initializes the YOLO model. Initializes the YOLO model.
Args: Args:
model (str, Path): model to load or create model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
task (Any, optional): Task type for the YOLO model. Defaults to None.
""" """
self._reset_callbacks() self._reset_callbacks()
self.predictor = None # reuse predictor self.predictor = None # reuse predictor
@ -85,10 +87,16 @@ class YOLO:
self.ckpt_path = None self.ckpt_path = None
self.overrides = {} # overrides for trainer object self.overrides = {} # overrides for trainer object
self.metrics = None # validation/training metrics self.metrics = None # validation/training metrics
self.session = session # HUB session self.session = None # HUB session
model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if model.startswith('https://hub.ultralytics.com/models/'):
from ultralytics.hub import HUBTrainingSession
self.session = HUBTrainingSession(model)
model = self.session.model_file
# Load or create new YOLO model # Load or create new YOLO model
model = str(model).strip() # strip spaces
suffix = Path(model).suffix suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS: if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
@ -280,6 +288,7 @@ class YOLO:
from ultralytics.yolo.utils.benchmarks import benchmark from ultralytics.yolo.utils.benchmarks import benchmark
overrides = self.model.args.copy() overrides = self.model.args.copy()
overrides.update(kwargs) overrides.update(kwargs)
overrides['mode'] = 'benchmark'
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device']) return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
@ -293,6 +302,7 @@ class YOLO:
self._check_is_pytorch_model() self._check_is_pytorch_model()
overrides = self.overrides.copy() overrides = self.overrides.copy()
overrides.update(kwargs) overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz: if args.imgsz == DEFAULT_CFG.imgsz:
@ -309,6 +319,11 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration. **kwargs (Any): Any number of arguments representing the training configuration.
""" """
self._check_is_pytorch_model() self._check_is_pytorch_model()
if self.session: # Ultralytics HUB session
if any(kwargs):
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
kwargs = self.session.train_args
self.session.check_disk_space()
check_pip_update_available() check_pip_update_available()
overrides = self.overrides.copy() overrides = self.overrides.copy()
overrides.update(kwargs) overrides.update(kwargs)

@ -277,6 +277,8 @@ class Masks(SimpleClass):
self.masks = masks # N, h, w self.masks = masks # N, h, w
self.orig_shape = orig_shape self.orig_shape = orig_shape
@property
@lru_cache(maxsize=1)
def segments(self): def segments(self):
# Segments-deprecated (normalized) # Segments-deprecated (normalized)
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and " LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "

@ -321,10 +321,13 @@ def is_online() -> bool:
bool: True if connection is successful, False otherwise. bool: True if connection is successful, False otherwise.
""" """
import socket import socket
with contextlib.suppress(Exception):
host = socket.gethostbyname('www.github.com') for server in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
socket.create_connection((host, 80), timeout=2) try:
return True socket.create_connection((server, 53), timeout=2) # connect to (server, port=53)
return True
except (socket.timeout, socket.gaierror, OSError):
continue
return False return False
@ -586,7 +589,7 @@ def set_sentry():
logging.getLogger(logger).setLevel(logging.CRITICAL) logging.getLogger(logger).setLevel(logging.CRITICAL)
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'): def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'):
""" """
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
@ -609,8 +612,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'):
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory. 'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
'weights_dir': str(root / 'weights'), # default weights directory. 'weights_dir': str(root / 'weights'), # default weights directory.
'runs_dir': str(root / 'runs'), # default runs directory. 'runs_dir': str(root / 'runs'), # default runs directory.
'sync': True, # sync analytics to help with YOLO development
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # anonymized uuid hash 'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # anonymized uuid hash
'sync': True, # sync analytics to help with YOLO development
'api_key': '', # Ultralytics HUB API key (https://hub.ultralytics.com/)
'settings_version': version} # Ultralytics settings version 'settings_version': version} # Ultralytics settings version
with torch_distributed_zero_first(RANK): with torch_distributed_zero_first(RANK):

@ -25,7 +25,7 @@ def on_pretrain_routine_end(trainer):
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000" mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
mlflow.set_tracking_uri(mlflow_location) mlflow.set_tracking_uri(mlflow_location)
experiment_name = trainer.args.project or 'YOLOv8' experiment_name = trainer.args.project or '/Shared/YOLOv8'
experiment = mlflow.get_experiment_by_name(experiment_name) experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None: if experiment is None:
mlflow.create_experiment(experiment_name) mlflow.create_experiment(experiment_name)
@ -33,16 +33,15 @@ def on_pretrain_routine_end(trainer):
prefix = colorstr('MLFlow: ') prefix = colorstr('MLFlow: ')
try: try:
run, active_run = mlflow, mlflow.start_run() if mlflow else None run, active_run = mlflow, mlflow.active_run()
if active_run is not None: if not active_run:
run_id = active_run.info.run_id active_run = mlflow.start_run(experiment_id=experiment.experiment_id)
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}') run_id = active_run.info.run_id
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
run.log_params(vars(trainer.model.args))
except Exception as err: except Exception as err:
LOGGER.error(f'{prefix}Failing init - {repr(err)}') LOGGER.error(f'{prefix}Failing init - {repr(err)}')
LOGGER.warning(f'{prefix}Continuing without Mlflow') LOGGER.warning(f'{prefix}Continuing without Mlflow')
run = None
run.log_params(vars(trainer.model.args))
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):

@ -142,7 +142,7 @@ def check_pip_update_available():
bool: True if an update is available, False otherwise. bool: True if an update is available, False otherwise.
""" """
if ONLINE and is_pip_package(): if ONLINE and is_pip_package():
with contextlib.suppress(ConnectionError): with contextlib.suppress(Exception):
from ultralytics import __version__ from ultralytics import __version__
latest = check_latest_pypi_version() latest = check_latest_pypi_version()
if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available

@ -12,7 +12,7 @@ import requests
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from ultralytics.yolo.utils import LOGGER, checks, is_online from ultralytics.yolo.utils import LOGGER, checks, emojis, is_online
GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \ GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \ [f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@ -113,9 +113,9 @@ def safe_download(url,
f.unlink() # remove partial downloads f.unlink() # remove partial downloads
except Exception as e: except Exception as e:
if i == 0 and not is_online(): if i == 0 and not is_online():
raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
elif i >= retry: elif i >= retry:
raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
if unzip and f.exists() and f.suffix in ('.zip', '.tar', '.gz'): if unzip and f.exists() and f.suffix in ('.zip', '.tar', '.gz'):

@ -114,7 +114,7 @@ class Annotator:
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255 self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
if im_gpu.device != masks.device: if im_gpu.device != masks.device:
im_gpu = im_gpu.to(masks.device) im_gpu = im_gpu.to(masks.device)
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
colors = colors[:, None, None] # shape(n,1,1,3) colors = colors[:, None, None] # shape(n,1,1,3)
masks = masks.unsqueeze(3) # shape(n,h,w,1) masks = masks.unsqueeze(3) # shape(n,h,w,1)
masks_color = masks * (colors * alpha) # shape(n,h,w,3) masks_color = masks * (colors * alpha) # shape(n,h,w,3)

@ -78,7 +78,7 @@ class SegmentationPredictor(DetectionPredictor):
for j, d in enumerate(reversed(det)): for j, d in enumerate(reversed(det)):
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
if self.args.save_txt: # Write to file if self.args.save_txt: # Write to file
seg = mask.segments[len(det) - j - 1].copy().reshape(-1) # reversed mask.segments, (n,2) to (n*2) seg = mask.xyn[len(det) - j - 1].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
with open(f'{self.txt_path}.txt', 'a') as f: with open(f'{self.txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n') f.write(('%g ' * len(line)).rstrip() % line + '\n')

Loading…
Cancel
Save