ultralytics 8.0.48 Edge TPU fix and Metrics updates (#1171)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
		
							
								
								
									
										77
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										77
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -12,6 +12,56 @@ on: | ||||
|     - cron: '0 0 * * *'  # runs at 00:00 UTC every day | ||||
|  | ||||
| jobs: | ||||
|   HUB: | ||||
|     runs-on: ${{ matrix.os }} | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         os: [ubuntu-latest] | ||||
|         python-version: ['3.10'] | ||||
|         model: [yolov5n] | ||||
|     steps: | ||||
|       - uses: actions/checkout@v3 | ||||
|       - uses: actions/setup-python@v4 | ||||
|         with: | ||||
|           python-version: ${{ matrix.python-version }} | ||||
|       - name: Get cache dir  # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow | ||||
|         id: pip-cache | ||||
|         run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT | ||||
|         shell: bash  # for Windows compatibility | ||||
|       - name: Cache pip | ||||
|         uses: actions/cache@v3 | ||||
|         with: | ||||
|           path: ${{ steps.pip-cache.outputs.dir }} | ||||
|           key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} | ||||
|           restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip- | ||||
|       - name: Install requirements | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           python -m pip install --upgrade pip wheel | ||||
|           pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu | ||||
|       - name: Check environment | ||||
|         run: | | ||||
|           echo "RUNNER_OS is ${{ runner.os }}" | ||||
|           echo "GITHUB_EVENT_NAME is ${{ github.event_name }}" | ||||
|           echo "GITHUB_WORKFLOW is ${{ github.workflow }}" | ||||
|           echo "GITHUB_ACTOR is ${{ github.actor }}" | ||||
|           echo "GITHUB_REPOSITORY is ${{ github.repository }}" | ||||
|           echo "GITHUB_REPOSITORY_OWNER is ${{ github.repository_owner }}" | ||||
|           python --version | ||||
|           pip --version | ||||
|           pip list | ||||
|       - name: Test HUB training | ||||
|         shell: python | ||||
|         env: | ||||
|           APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} | ||||
|         run: | | ||||
|           import os | ||||
|           from ultralytics import hub | ||||
|           key = os.environ['APIKEY'] | ||||
|           hub.reset_model(key) | ||||
|           hub.start(key) | ||||
|  | ||||
|   Benchmarks: | ||||
|     runs-on: ${{ matrix.os }} | ||||
|     strategy: | ||||
| @ -25,12 +75,16 @@ jobs: | ||||
|       - uses: actions/setup-python@v4 | ||||
|         with: | ||||
|           python-version: ${{ matrix.python-version }} | ||||
|       #- name: Cache pip | ||||
|       #  uses: actions/cache@v3 | ||||
|       #  with: | ||||
|       #    path: ~/.cache/pip | ||||
|       #    key: ${{ runner.os }}-Benchmarks-${{ hashFiles('requirements.txt') }} | ||||
|       #    restore-keys: ${{ runner.os }}-Benchmarks- | ||||
|       - name: Get cache dir  # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow | ||||
|         id: pip-cache | ||||
|         run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT | ||||
|         shell: bash  # for Windows compatibility | ||||
|       - name: Cache pip | ||||
|         uses: actions/cache@v3 | ||||
|         with: | ||||
|           path: ${{ steps.pip-cache.outputs.dir }} | ||||
|           key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} | ||||
|           restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip- | ||||
|       - name: Install requirements | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
| @ -120,17 +174,6 @@ jobs: | ||||
|           python --version | ||||
|           pip --version | ||||
|           pip list | ||||
|       - name: Test pip package | ||||
|         shell: python | ||||
|         env: | ||||
|           APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} | ||||
|         run: | | ||||
|           import os | ||||
|           import ultralytics | ||||
|           key = os.environ['APIKEY'] | ||||
|           ultralytics.checks() | ||||
|           # ultralytics.reset_model(key)  # reset trained model | ||||
|           # ultralytics.start(key)  # train model | ||||
|       - name: Test detection | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|  | ||||
| @ -28,6 +28,29 @@ predictor's call method. | ||||
|         probs = r.probs  # Class probabilities for classification outputs | ||||
|     ``` | ||||
|  | ||||
| ## Sources | ||||
|  | ||||
| YOLOv8 can run inference on a variety of sources. The table below lists the various sources that can be used as input | ||||
| for YOLOv8, along with the required format and notes. Sources include images, URLs, PIL images, OpenCV, numpy arrays, | ||||
| torch tensors, CSV files, videos, directories, globs, YouTube videos, and streams. The table also indicates whether each | ||||
| source can be used as a stream and the model argument required for that source. | ||||
|  | ||||
| | source     | stream  | model(arg)                                 | type           | notes            | | ||||
| |------------|---------|--------------------------------------------|----------------|------------------| | ||||
| | image      |         | `'im.jpg'`                                 | `str`, `Path`  |                  | | ||||
| | URL        |         | `'https://ultralytics.com/images/bus.jpg'` | `str`          |                  | | ||||
| | screenshot |         | `'screen'`                                 | `str`          |                  | | ||||
| | PIL        |         | `Image.open('im.jpg')`                     | `PIL.Image`    | HWC, RGB         | | ||||
| | OpenCV     |         | `cv2.imread('im.jpg')[:,:,::-1]`           | `np.ndarray`   | HWC, BGR to RGB  | | ||||
| | numpy      |         | `np.zeros((640,1280,3))`                   | `np.ndarray`   | HWC              | | ||||
| | torch      |         | `torch.zeros(16,3,320,640)`                | `torch.Tensor` | BCHW, RGB        | | ||||
| | CSV        |         | `'sources.csv'`                            | `str`, `Path`  | RTSP, RTMP, HTTP |          | ||||
| | video      | ✓ | `'vid.mp4'`                                | `str`, `Path`  |                  | | ||||
| | directory  | ✓ | `'path/'`                                  | `str`, `Path`  |                  | | ||||
| | glob       | ✓ | `path/*.jpg'`                              | `str`          | Use `*` operator | | ||||
| | YouTube    | ✓ | `'https://youtu.be/Zgi9g1ksQHc'`           | `str`          |                  | | ||||
| | stream     | ✓ | `'rtsp://example.com/media.mp4'`           | `str`          | RTSP, RTMP, HTTP | | ||||
|  | ||||
| ## Working with Results | ||||
|  | ||||
| Results object consists of these component objects: | ||||
|  | ||||
| @ -645,7 +645,7 @@ | ||||
|       "cell_type": "code", | ||||
|       "source": [ | ||||
|         "# Git clone install (for development)\n", | ||||
|         "!git clone https://github.com/ultralytics/ultralytics\n", | ||||
|         "!git clone https://github.com/ultralytics/ultralytics -b main\n", | ||||
|         "%pip install -qe ultralytics" | ||||
|       ], | ||||
|       "metadata": { | ||||
|  | ||||
| @ -3,7 +3,7 @@ | ||||
| import subprocess | ||||
| from pathlib import Path | ||||
|  | ||||
| from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks | ||||
| from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS | ||||
|  | ||||
| MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' | ||||
| CFG = 'yolov8n' | ||||
| @ -49,7 +49,7 @@ def test_val_classify(): | ||||
| # Predict checks ------------------------------------------------------------------------------------------------------- | ||||
| def test_predict_detect(): | ||||
|     run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32 save save_crop save_txt") | ||||
|     if checks.check_online(): | ||||
|     if ONLINE: | ||||
|         run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32') | ||||
|         run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32') | ||||
|         run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32') | ||||
|  | ||||
| @ -9,7 +9,7 @@ from PIL import Image | ||||
|  | ||||
| from ultralytics import YOLO | ||||
| from ultralytics.yolo.data.build import load_inference_source | ||||
| from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks | ||||
| from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS | ||||
|  | ||||
| MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt' | ||||
| CFG = 'yolov8n.yaml' | ||||
| @ -58,7 +58,7 @@ def test_predict_img(): | ||||
|     batch = [ | ||||
|         str(SOURCE),  # filename | ||||
|         Path(SOURCE),  # Path | ||||
|         'https://ultralytics.com/images/zidane.jpg' if checks.check_online() else SOURCE,  # URI | ||||
|         'https://ultralytics.com/images/zidane.jpg' if ONLINE else SOURCE,  # URI | ||||
|         cv2.imread(str(SOURCE)),  # OpenCV | ||||
|         Image.open(SOURCE),  # PIL | ||||
|         np.zeros((320, 640, 3))]  # numpy | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
|  | ||||
| __version__ = '8.0.47' | ||||
| __version__ = '8.0.48' | ||||
|  | ||||
| from ultralytics.yolo.engine.model import YOLO | ||||
| from ultralytics.yolo.utils.checks import check_yolo as checks | ||||
|  | ||||
| @ -3,11 +3,11 @@ | ||||
| import requests | ||||
|  | ||||
| from ultralytics.hub.auth import Auth | ||||
| from ultralytics.hub.session import HubTrainingSession | ||||
| from ultralytics.hub.utils import split_key | ||||
| from ultralytics.hub.session import HUBTrainingSession | ||||
| from ultralytics.hub.utils import PREFIX, split_key | ||||
| from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST | ||||
| from ultralytics.yolo.engine.model import YOLO | ||||
| from ultralytics.yolo.utils import LOGGER, PREFIX, emojis | ||||
| from ultralytics.yolo.utils import LOGGER, emojis | ||||
|  | ||||
| # Define all export formats | ||||
| EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml'] | ||||
| @ -18,23 +18,19 @@ def start(key=''): | ||||
|     Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY') | ||||
|     """ | ||||
|     auth = Auth(key) | ||||
|     try: | ||||
|         if not auth.get_state(): | ||||
|             model_id = request_api_key(auth) | ||||
|         else: | ||||
|             _, model_id = split_key(key) | ||||
|     if not auth.get_state(): | ||||
|         model_id = request_api_key(auth) | ||||
|     else: | ||||
|         _, model_id = split_key(key) | ||||
|  | ||||
|         if not model_id: | ||||
|             raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌')) | ||||
|     if not model_id: | ||||
|         raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌')) | ||||
|  | ||||
|         session = HubTrainingSession(model_id=model_id, auth=auth) | ||||
|         session.check_disk_space() | ||||
|     session = HUBTrainingSession(model_id=model_id, auth=auth) | ||||
|     session.check_disk_space() | ||||
|  | ||||
|         model = YOLO(session.input_file) | ||||
|         session.register_callbacks(model) | ||||
|         model.train(**session.train_args) | ||||
|     except Exception as e: | ||||
|         LOGGER.warning(f'{PREFIX}{e}') | ||||
|     model = YOLO(model=session.model_file, session=session) | ||||
|     model.train(**session.train_args) | ||||
|  | ||||
|  | ||||
| def request_api_key(auth, max_attempts=3): | ||||
| @ -62,9 +58,9 @@ def reset_model(key=''): | ||||
|     r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id}) | ||||
|  | ||||
|     if r.status_code == 200: | ||||
|         LOGGER.info(f'{PREFIX}model reset successfully') | ||||
|         LOGGER.info(f'{PREFIX}Model reset successfully') | ||||
|         return | ||||
|     LOGGER.warning(f'{PREFIX}model reset failure {r.status_code} {r.reason}') | ||||
|     LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}') | ||||
|  | ||||
|  | ||||
| def export_model(key='', format='torchscript'): | ||||
| @ -76,7 +72,7 @@ def export_model(key='', format='torchscript'): | ||||
|                           'apiKey': api_key, | ||||
|                           'modelId': model_id, | ||||
|                           'format': format}) | ||||
|     assert (r.status_code == 200), f'{PREFIX}{format} export failure {r.status_code} {r.reason}' | ||||
|     assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}' | ||||
|     LOGGER.info(f'{PREFIX}{format} export started ✅') | ||||
|  | ||||
|  | ||||
| @ -89,7 +85,7 @@ def get_export(key='', format='torchscript'): | ||||
|                           'apiKey': api_key, | ||||
|                           'modelId': model_id, | ||||
|                           'format': format}) | ||||
|     assert (r.status_code == 200), f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' | ||||
|     assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' | ||||
|     return r.json() | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1,30 +1,27 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| import json | ||||
| import signal | ||||
| import sys | ||||
| from pathlib import Path | ||||
| from time import sleep, time | ||||
| from time import sleep | ||||
|  | ||||
| import requests | ||||
|  | ||||
| from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request | ||||
| from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded | ||||
| from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params | ||||
| from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded | ||||
|  | ||||
| AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' | ||||
| session = None | ||||
|  | ||||
|  | ||||
| class HubTrainingSession: | ||||
| class HUBTrainingSession: | ||||
|  | ||||
|     def __init__(self, model_id, auth): | ||||
|         self.agent_id = None  # identifies which instance is communicating with server | ||||
|         self.model_id = model_id | ||||
|         self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' | ||||
|         self.auth_header = auth.get_auth_header() | ||||
|         self._rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds) | ||||
|         self._timers = {}  # rate limit timers (seconds) | ||||
|         self._metrics_queue = {}  # metrics queue | ||||
|         self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds) | ||||
|         self.timers = {}  # rate limit timers (seconds) | ||||
|         self.metrics_queue = {}  # metrics queue | ||||
|         self.model = self._get_model() | ||||
|         self.alive = True | ||||
|         self._start_heartbeat()  # start heartbeats | ||||
| @ -50,16 +47,15 @@ class HubTrainingSession: | ||||
|         self.alive = False | ||||
|  | ||||
|     def upload_metrics(self): | ||||
|         payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'} | ||||
|         smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2) | ||||
|         payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'} | ||||
|         smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2) | ||||
|  | ||||
|     def _get_model(self): | ||||
|         # Returns model from database by id | ||||
|         api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' | ||||
|         headers = self.auth_header | ||||
|  | ||||
|         try: | ||||
|             response = smart_request(api_url, method='get', headers=headers, thread=False, code=0) | ||||
|             response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0) | ||||
|             data = response.json().get('data', None) | ||||
|  | ||||
|             if data.get('status', None) == 'trained': | ||||
| @ -82,11 +78,8 @@ class HubTrainingSession: | ||||
|                 'cache': data['cache'], | ||||
|                 'data': data['data']} | ||||
|  | ||||
|             self.input_file = data.get('cfg', data['weights']) | ||||
|  | ||||
|             # hack for yolov5 cfg adds u | ||||
|             if 'cfg' in data and 'yolov5' in data['cfg']: | ||||
|                 self.input_file = data['cfg'].replace('.yaml', 'u.yaml') | ||||
|             self.model_file = data.get('cfg', data['weights']) | ||||
|             self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u | ||||
|  | ||||
|             return data | ||||
|         except requests.exceptions.ConnectionError as e: | ||||
| @ -98,86 +91,44 @@ class HubTrainingSession: | ||||
|         if not check_dataset_disk_space(self.model['data']): | ||||
|             raise MemoryError('Not enough disk space') | ||||
|  | ||||
|     def register_callbacks(self, trainer): | ||||
|         trainer.add_callback('on_pretrain_routine_end', self.on_pretrain_routine_end) | ||||
|         trainer.add_callback('on_fit_epoch_end', self.on_fit_epoch_end) | ||||
|         trainer.add_callback('on_model_save', self.on_model_save) | ||||
|         trainer.add_callback('on_train_end', self.on_train_end) | ||||
|  | ||||
|     def on_pretrain_routine_end(self, trainer): | ||||
|         """ | ||||
|         Start timer for upload rate limit. | ||||
|         This method does not use trainer. It is passed to all callbacks by default. | ||||
|         """ | ||||
|         # Start timer for upload rate limit | ||||
|         LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀') | ||||
|         self._timers = {'metrics': time(), 'ckpt': time()}  # start timer on self.rate_limit | ||||
|  | ||||
|     def on_fit_epoch_end(self, trainer): | ||||
|         # Upload metrics after val end | ||||
|         all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} | ||||
|  | ||||
|         if trainer.epoch == 0: | ||||
|             model_info = { | ||||
|                 'model/parameters': get_num_params(trainer.model), | ||||
|                 'model/GFLOPs': round(get_flops(trainer.model), 3), | ||||
|                 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} | ||||
|             all_plots = {**all_plots, **model_info} | ||||
|         self._metrics_queue[trainer.epoch] = json.dumps(all_plots) | ||||
|         if time() - self._timers['metrics'] > self._rate_limits['metrics']: | ||||
|             self.upload_metrics() | ||||
|             self._timers['metrics'] = time()  # reset timer | ||||
|             self._metrics_queue = {}  # reset queue | ||||
|  | ||||
|     def on_model_save(self, trainer): | ||||
|         # Upload checkpoints with rate limiting | ||||
|         is_best = trainer.best_fitness == trainer.fitness | ||||
|         if time() - self._timers['ckpt'] > self._rate_limits['ckpt']: | ||||
|             LOGGER.info(f'{PREFIX}Uploading checkpoint {self.model_id}') | ||||
|             self._upload_model(trainer.epoch, trainer.last, is_best) | ||||
|             self._timers['ckpt'] = time()  # reset timer | ||||
|  | ||||
|     def on_train_end(self, trainer): | ||||
|         # Upload final model and metrics with exponential standoff | ||||
|         LOGGER.info(f'{PREFIX}Training completed successfully ✅\n' | ||||
|                     f'{PREFIX}Uploading final {self.model_id}') | ||||
|  | ||||
|         self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) | ||||
|         self.alive = False  # stop heartbeats | ||||
|         LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀') | ||||
|  | ||||
|     def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): | ||||
|     def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): | ||||
|         # Upload a model to HUB | ||||
|         if Path(weights).is_file(): | ||||
|             with open(weights, 'rb') as f: | ||||
|                 file = f.read() | ||||
|         else: | ||||
|             LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload failed. Missing model {weights}.') | ||||
|             LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.') | ||||
|             file = None | ||||
|         url = f'{self.api_url}/upload' | ||||
|         # url = 'http://httpbin.org/post'  # for debug | ||||
|         data = {'epoch': epoch} | ||||
|         if final: | ||||
|             data.update({'type': 'final', 'map': map}) | ||||
|             smart_request('post', | ||||
|                           url, | ||||
|                           data=data, | ||||
|                           files={'best.pt': file}, | ||||
|                           headers=self.auth_header, | ||||
|                           retry=10, | ||||
|                           timeout=3600, | ||||
|                           thread=False, | ||||
|                           progress=True, | ||||
|                           code=4) | ||||
|         else: | ||||
|             data.update({'type': 'epoch', 'isBest': bool(is_best)}) | ||||
|  | ||||
|         smart_request(f'{self.api_url}/upload', | ||||
|                       data=data, | ||||
|                       files={'best.pt' if final else 'last.pt': file}, | ||||
|                       headers=self.auth_header, | ||||
|                       retry=10 if final else None, | ||||
|                       timeout=3600 if final else None, | ||||
|                       code=4 if final else 3) | ||||
|             smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3) | ||||
|  | ||||
|     @threaded | ||||
|     def _start_heartbeat(self): | ||||
|         while self.alive: | ||||
|             r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', | ||||
|             r = smart_request('post', | ||||
|                               f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', | ||||
|                               json={ | ||||
|                                   'agent': AGENT_NAME, | ||||
|                                   'agentId': self.agent_id}, | ||||
|                               headers=self.auth_header, | ||||
|                               retry=0, | ||||
|                               code=5, | ||||
|                               thread=False) | ||||
|                               thread=False)  # already in a thread | ||||
|             self.agent_id = r.json().get('data', {}).get('agentId', None) | ||||
|             sleep(self._rate_limits['heartbeat']) | ||||
|             sleep(self.rate_limits['heartbeat']) | ||||
|  | ||||
| @ -10,13 +10,13 @@ from pathlib import Path | ||||
| from random import random | ||||
|  | ||||
| import requests | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, RANK, SETTINGS, TESTS_RUNNING, TryExcept, | ||||
|                                     __version__, colorstr, emojis, get_git_origin_url, is_colab, is_git_dir, | ||||
|                                     is_pip_package) | ||||
| from ultralytics.yolo.utils.checks import check_online | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, | ||||
|                                     TQDM_BAR_FORMAT, TryExcept, __version__, colorstr, emojis, get_git_origin_url, | ||||
|                                     is_colab, is_git_dir, is_pip_package) | ||||
|  | ||||
| PREFIX = colorstr('Ultralytics: ') | ||||
| PREFIX = colorstr('Ultralytics HUB: ') | ||||
| HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.' | ||||
| HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com') | ||||
|  | ||||
| @ -60,7 +60,6 @@ def request_with_credentials(url: str) -> any: | ||||
|     return output.eval_js('_hub_tmp') | ||||
|  | ||||
|  | ||||
| # Deprecated TODO: eliminate this function? | ||||
| def split_key(key=''): | ||||
|     """ | ||||
|     Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_' | ||||
| @ -84,36 +83,61 @@ def split_key(key=''): | ||||
|     return api_key, model_id | ||||
|  | ||||
|  | ||||
| def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post', verbose=True, **kwargs): | ||||
| def requests_with_progress(method, url, **kwargs): | ||||
|     """ | ||||
|     Make an HTTP request using the specified method and URL, with an optional progress bar. | ||||
|  | ||||
|     Args: | ||||
|         method (str): The HTTP method to use (e.g. 'GET', 'POST'). | ||||
|         url (str): The URL to send the request to. | ||||
|         progress (bool, optional): Whether to display a progress bar. Defaults to False. | ||||
|         **kwargs: Additional keyword arguments to pass to the underlying `requests.request` function. | ||||
|  | ||||
|     Returns: | ||||
|         requests.Response: The response from the HTTP request. | ||||
|  | ||||
|     """ | ||||
|     progress = kwargs.pop('progress', False) | ||||
|     if not progress: | ||||
|         return requests.request(method, url, **kwargs) | ||||
|     response = requests.request(method, url, stream=True, **kwargs) | ||||
|     total = int(response.headers.get('content-length', 0))  # total size | ||||
|     pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT) | ||||
|     for data in response.iter_content(chunk_size=1024): | ||||
|         pbar.update(len(data)) | ||||
|     pbar.close() | ||||
|     return response | ||||
|  | ||||
|  | ||||
| def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs): | ||||
|     """ | ||||
|     Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout. | ||||
|  | ||||
|     Args: | ||||
|         *args: Positional arguments to be passed to the requests function specified in method. | ||||
|         method (str): The HTTP method to use for the request. Choices are 'post' and 'get'. | ||||
|         url (str): The URL to make the request to. | ||||
|         retry (int, optional): Number of retries to attempt before giving up. Default is 3. | ||||
|         timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30. | ||||
|         thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True. | ||||
|         code (int, optional): An identifier for the request, used for logging purposes. Default is -1. | ||||
|         method (str, optional): The HTTP method to use for the request. Choices are 'post' and 'get'. Default is 'post'. | ||||
|         verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True. | ||||
|         progress (bool, optional): Whether to show a progress bar during the request. Default is False. | ||||
|         **kwargs: Keyword arguments to be passed to the requests function specified in method. | ||||
|  | ||||
|     Returns: | ||||
|         requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None. | ||||
|  | ||||
|     """ | ||||
|     retry_codes = (408, 500)  # retry only these codes | ||||
|  | ||||
|     @TryExcept(verbose=verbose) | ||||
|     def func(*func_args, **func_kwargs): | ||||
|     def func(func_method, func_url, **func_kwargs): | ||||
|         r = None  # response | ||||
|         t0 = time.time()  # initial time for timer | ||||
|         for i in range(retry + 1): | ||||
|             if (time.time() - t0) > timeout: | ||||
|                 break | ||||
|             if method == 'post': | ||||
|                 r = requests.post(*func_args, **func_kwargs)  # i.e. post(url, data, json, files) | ||||
|             elif method == 'get': | ||||
|                 r = requests.get(*func_args, **func_kwargs)  # i.e. get(url, data, json, files) | ||||
|             r = requests_with_progress(func_method, func_url, **func_kwargs)  # i.e. get(url, data, json, files) | ||||
|             if r.status_code == 200: | ||||
|                 break | ||||
|             try: | ||||
| @ -134,6 +158,8 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post | ||||
|             time.sleep(2 ** i)  # exponential standoff | ||||
|         return r | ||||
|  | ||||
|     args = method, url | ||||
|     kwargs['progress'] = progress | ||||
|     if thread: | ||||
|         threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start() | ||||
|     else: | ||||
| @ -157,8 +183,8 @@ class Traces: | ||||
|         self.enabled = \ | ||||
|             SETTINGS['sync'] and \ | ||||
|             RANK in {-1, 0} and \ | ||||
|             check_online() and \ | ||||
|             not TESTS_RUNNING and \ | ||||
|             ONLINE and \ | ||||
|             (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') | ||||
|  | ||||
|     def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0): | ||||
| @ -182,13 +208,7 @@ class Traces: | ||||
|             trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata} | ||||
|  | ||||
|             # Send a request to the HUB API to sync analytics | ||||
|             smart_request(f'{HUB_API_ROOT}/v1/usage/anonymous', | ||||
|                           json=trace, | ||||
|                           headers=None, | ||||
|                           code=3, | ||||
|                           retry=0, | ||||
|                           timeout=1.0, | ||||
|                           verbose=False) | ||||
|             smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False) | ||||
|  | ||||
|  | ||||
| # Run below code on hub/utils init ------------------------------------------------------------------------------------- | ||||
|  | ||||
| @ -13,7 +13,7 @@ from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_P | ||||
|  | ||||
| CLI_HELP_MSG = \ | ||||
|     f""" | ||||
|     Arguments received: {str(['yolo'] + sys.argv[1:])}. Note that Ultralytics 'yolo' commands use the following syntax: | ||||
|     Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax: | ||||
|  | ||||
|         yolo TASK MODE ARGS | ||||
|  | ||||
| @ -217,6 +217,9 @@ def entrypoint(debug=''): | ||||
|         if a.startswith('--'): | ||||
|             LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") | ||||
|             a = a[2:] | ||||
|         if a.endswith(','): | ||||
|             LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") | ||||
|             a = a[:-1] | ||||
|         if '=' in a: | ||||
|             try: | ||||
|                 re.sub(r' *= *', '=', a)  # remove spaces around equals sign | ||||
| @ -284,6 +287,9 @@ def entrypoint(debug=''): | ||||
|     model = YOLO(model, task=task) | ||||
|  | ||||
|     # Task Update | ||||
|     if task and task != model.task: | ||||
|         LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " | ||||
|                        f'This may produce errors.') | ||||
|     task = task or model.task | ||||
|     overrides['task'] = task | ||||
|  | ||||
|  | ||||
| @ -243,15 +243,12 @@ class Exporter: | ||||
|         if coreml:  # CoreML | ||||
|             f[4], _ = self._export_coreml() | ||||
|         if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats | ||||
|             LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export is still under development. ' | ||||
|                            'Please consider contributing to the effort if you have TF expertise. Thank you!') | ||||
|             nms = False | ||||
|             self.args.int8 |= edgetpu | ||||
|             f[5], s_model = self._export_saved_model() | ||||
|             if pb or tfjs:  # pb prerequisite to tfjs | ||||
|                 f[6], _ = self._export_pb(s_model) | ||||
|             if tflite: | ||||
|                 f[7], _ = self._export_tflite(s_model, nms=nms, agnostic_nms=self.args.agnostic_nms) | ||||
|                 f[7], _ = self._export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms) | ||||
|             if edgetpu: | ||||
|                 f[8], _ = self._export_edgetpu(tflite_model=str( | ||||
|                     Path(f[5]) / (self.file.stem + '_full_integer_quant.tflite')))  # int8 in/out | ||||
| @ -619,20 +616,18 @@ class Exporter: | ||||
|     @try_export | ||||
|     def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')): | ||||
|         # YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ | ||||
|         LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185') | ||||
|  | ||||
|         cmd = 'edgetpu_compiler --version' | ||||
|         help_url = 'https://coral.ai/docs/edgetpu/compiler/' | ||||
|         assert LINUX, f'export only supported on Linux. See {help_url}' | ||||
|         if subprocess.run(f'{cmd} > /dev/null', shell=True).returncode != 0: | ||||
|         if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0: | ||||
|             LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') | ||||
|             sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system | ||||
|             for c in ( | ||||
|                     # 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',  # errors | ||||
|                     'wget --no-check-certificate -q -O - https://packages.cloud.google.com/apt/doc/apt-key.gpg | ' | ||||
|                     'sudo apt-key add -', | ||||
|                     'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '  # no comma | ||||
|                     'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', | ||||
|                     'sudo apt-get update', | ||||
|                     'sudo apt-get install edgetpu-compiler'): | ||||
|                     'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', | ||||
|                     'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', | ||||
|                     'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'): | ||||
|                 subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) | ||||
|         ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] | ||||
|  | ||||
|  | ||||
| @ -43,7 +43,7 @@ class YOLO: | ||||
|         cfg (str): The model configuration if loaded from *.yaml file. | ||||
|         ckpt_path (str): The checkpoint file path. | ||||
|         overrides (dict): Overrides for the trainer object. | ||||
|         metrics_data (Any): The data for metrics. | ||||
|         metrics (Any): The data for metrics. | ||||
|  | ||||
|     Methods: | ||||
|         __call__(source=None, stream=False, **kwargs): | ||||
| @ -67,7 +67,7 @@ class YOLO: | ||||
|         list(ultralytics.yolo.engine.results.Results): The prediction results. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, model='yolov8n.pt', task=None) -> None: | ||||
|     def __init__(self, model='yolov8n.pt', task=None, session=None) -> None: | ||||
|         """ | ||||
|         Initializes the YOLO model. | ||||
|  | ||||
| @ -83,7 +83,8 @@ class YOLO: | ||||
|         self.cfg = None  # if loaded from *.yaml | ||||
|         self.ckpt_path = None | ||||
|         self.overrides = {}  # overrides for trainer object | ||||
|         self.metrics_data = None | ||||
|         self.metrics = None  # validation/training metrics | ||||
|         self.session = session  # HUB session | ||||
|  | ||||
|         # Load or create new YOLO model | ||||
|         suffix = Path(model).suffix | ||||
| @ -184,6 +185,7 @@ class YOLO: | ||||
|         self._check_is_pytorch_model() | ||||
|         self.model.fuse() | ||||
|  | ||||
|     @smart_inference_mode() | ||||
|     def predict(self, source=None, stream=False, **kwargs): | ||||
|         """ | ||||
|         Perform prediction using the YOLO model. | ||||
| @ -217,7 +219,6 @@ class YOLO: | ||||
|         is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics') | ||||
|         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) | ||||
|  | ||||
|     @smart_inference_mode() | ||||
|     def track(self, source=None, stream=False, **kwargs): | ||||
|         from ultralytics.tracker import register_tracker | ||||
|         register_tracker(self) | ||||
| @ -252,7 +253,7 @@ class YOLO: | ||||
|  | ||||
|         validator = TASK_MAP[self.task][2](args=args) | ||||
|         validator(model=self.model) | ||||
|         self.metrics_data = validator.metrics | ||||
|         self.metrics = validator.metrics | ||||
|  | ||||
|         return validator.metrics | ||||
|  | ||||
| @ -314,12 +315,13 @@ class YOLO: | ||||
|         if not overrides.get('resume'):  # manually set model only if not resuming | ||||
|             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) | ||||
|             self.model = self.trainer.model | ||||
|         self.trainer.hub_session = self.session  # attach optional HUB session | ||||
|         self.trainer.train() | ||||
|         # update model and cfg after training | ||||
|         if RANK in {0, -1}: | ||||
|             self.model, _ = attempt_load_one_weight(str(self.trainer.best)) | ||||
|             self.overrides = self.model.args | ||||
|             self.metrics_data = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP | ||||
|             self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP | ||||
|  | ||||
|     def to(self, device): | ||||
|         """ | ||||
| @ -352,15 +354,6 @@ class YOLO: | ||||
|         """ | ||||
|         return self.model.transforms if hasattr(self.model, 'transforms') else None | ||||
|  | ||||
|     @property | ||||
|     def metrics(self): | ||||
|         """ | ||||
|         Returns metrics if computed | ||||
|         """ | ||||
|         if not self.metrics_data: | ||||
|             LOGGER.info('No metrics data found! Run training or validation operation first.') | ||||
|         return self.metrics_data | ||||
|  | ||||
|     @staticmethod | ||||
|     def add_callback(event: str, func): | ||||
|         """ | ||||
|  | ||||
| @ -139,7 +139,8 @@ class Results: | ||||
|             annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im) | ||||
|  | ||||
|         if logits is not None: | ||||
|             top5i = logits.argsort(0, descending=True)[:5].tolist()  # top 5 indices | ||||
|             n5 = min(len(self.names), 5) | ||||
|             top5i = logits.argsort(0, descending=True)[:n5].tolist()  # top 5 indices | ||||
|             text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, " | ||||
|             annotator.text((32, 32), text, txt_color=(255, 255, 255))  # TODO: allow setting colors | ||||
|  | ||||
|  | ||||
| @ -243,6 +243,24 @@ def is_docker() -> bool: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def is_online() -> bool: | ||||
|     """ | ||||
|     Check internet connectivity by attempting to connect to a known online host. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if connection is successful, False otherwise. | ||||
|     """ | ||||
|     import socket | ||||
|     with contextlib.suppress(Exception): | ||||
|         host = socket.gethostbyname('www.github.com') | ||||
|         socket.create_connection((host, 80), timeout=2) | ||||
|         return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| ONLINE = is_online() | ||||
|  | ||||
|  | ||||
| def is_pip_package(filepath: str = __name__) -> bool: | ||||
|     """ | ||||
|     Determines if the file at the given filepath is part of a pip package. | ||||
| @ -513,6 +531,7 @@ def set_sentry(): | ||||
|             RANK in {-1, 0} and \ | ||||
|             Path(sys.argv[0]).name == 'yolo' and \ | ||||
|             not TESTS_RUNNING and \ | ||||
|             ONLINE and \ | ||||
|             ((is_pip_package() and not is_git_dir()) or | ||||
|              (get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')): | ||||
|  | ||||
|  | ||||
| @ -151,4 +151,5 @@ def add_integration_callbacks(instance): | ||||
|  | ||||
|     for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks: | ||||
|         for k, v in x.items(): | ||||
|             instance.callbacks[k].append(v)  # callback[name].append(func) | ||||
|             if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition | ||||
|                 instance.callbacks[k].append(v)  # callback[name].append(func) | ||||
|  | ||||
| @ -4,24 +4,33 @@ import json | ||||
| from time import time | ||||
|  | ||||
| from ultralytics.hub.utils import PREFIX, traces | ||||
| from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING | ||||
| from ultralytics.yolo.utils import LOGGER | ||||
| from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params | ||||
|  | ||||
|  | ||||
| def on_pretrain_routine_end(trainer): | ||||
|     session = not TESTS_RUNNING and getattr(trainer, 'hub_session', None) | ||||
|     session = getattr(trainer, 'hub_session', None) | ||||
|     if session: | ||||
|         # Start timer for upload rate limit | ||||
|         LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') | ||||
|         session.t = {'metrics': time(), 'ckpt': time()}  # start timer on self.rate_limit | ||||
|         session.timers = {'metrics': time(), 'ckpt': time()}  # start timer on session.rate_limit | ||||
|  | ||||
|  | ||||
| def on_fit_epoch_end(trainer): | ||||
|     session = getattr(trainer, 'hub_session', None) | ||||
|     if session: | ||||
|         session.metrics_queue[trainer.epoch] = json.dumps(trainer.metrics)  # json string | ||||
|         if time() - session.t['metrics'] > session.rate_limits['metrics']: | ||||
|         # Upload metrics after val end | ||||
|         all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} | ||||
|         if trainer.epoch == 0: | ||||
|             model_info = { | ||||
|                 'model/parameters': get_num_params(trainer.model), | ||||
|                 'model/GFLOPs': round(get_flops(trainer.model), 3), | ||||
|                 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} | ||||
|             all_plots = {**all_plots, **model_info} | ||||
|         session.metrics_queue[trainer.epoch] = json.dumps(all_plots) | ||||
|         if time() - session.timers['metrics'] > session.rate_limits['metrics']: | ||||
|             session.upload_metrics() | ||||
|             session.t['metrics'] = time()  # reset timer | ||||
|             session.timers['metrics'] = time()  # reset timer | ||||
|             session.metrics_queue = {}  # reset queue | ||||
|  | ||||
|  | ||||
| @ -30,21 +39,21 @@ def on_model_save(trainer): | ||||
|     if session: | ||||
|         # Upload checkpoints with rate limiting | ||||
|         is_best = trainer.best_fitness == trainer.fitness | ||||
|         if time() - session.t['ckpt'] > session.rate_limits['ckpt']: | ||||
|         if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: | ||||
|             LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}') | ||||
|             session.upload_model(trainer.epoch, trainer.last, is_best) | ||||
|             session.t['ckpt'] = time()  # reset timer | ||||
|             session.timers['ckpt'] = time()  # reset timer | ||||
|  | ||||
|  | ||||
| def on_train_end(trainer): | ||||
|     session = getattr(trainer, 'hub_session', None) | ||||
|     if session: | ||||
|         # Upload final model and metrics with exponential standoff | ||||
|         LOGGER.info(f'{PREFIX}Training completed successfully ✅\n' | ||||
|                     f'{PREFIX}Uploading final {session.model_id}') | ||||
|         session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True) | ||||
|         session.shutdown()  # stop heartbeats | ||||
|         LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') | ||||
|         LOGGER.info(f'{PREFIX}Syncing final model...') | ||||
|         session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) | ||||
|         session.alive = False  # stop heartbeats | ||||
|         LOGGER.info(f'{PREFIX}Done ✅\n' | ||||
|                     f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') | ||||
|  | ||||
|  | ||||
| def on_train_start(trainer): | ||||
|  | ||||
| @ -1,8 +1,12 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING | ||||
|  | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| try: | ||||
|     from torch.utils.tensorboard import SummaryWriter | ||||
|  | ||||
| from ultralytics.yolo.utils import LOGGER | ||||
|     assert not TESTS_RUNNING  # do not log pytest | ||||
| except (ImportError, AssertionError): | ||||
|     SummaryWriter = None | ||||
|  | ||||
| writer = None  # TensorBoard SummaryWriter instance | ||||
|  | ||||
| @ -18,7 +22,6 @@ def on_pretrain_routine_start(trainer): | ||||
|     try: | ||||
|         writer = SummaryWriter(str(trainer.save_dir)) | ||||
|     except Exception as e: | ||||
|         writer = None  # TensorBoard SummaryWriter instance | ||||
|         LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}') | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -21,7 +21,7 @@ import torch | ||||
| from matplotlib import font_manager | ||||
|  | ||||
| from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis, | ||||
|                                     is_colab, is_docker, is_jupyter) | ||||
|                                     is_colab, is_docker, is_jupyter, is_online) | ||||
|  | ||||
|  | ||||
| def is_ascii(s) -> bool: | ||||
| @ -171,21 +171,6 @@ def check_font(font='Arial.ttf'): | ||||
|         return file | ||||
|  | ||||
|  | ||||
| def check_online() -> bool: | ||||
|     """ | ||||
|     Check internet connectivity by attempting to connect to a known online host. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if connection is successful, False otherwise. | ||||
|     """ | ||||
|     import socket | ||||
|     with contextlib.suppress(Exception): | ||||
|         host = socket.gethostbyname('www.github.com') | ||||
|         socket.create_connection((host, 80), timeout=2) | ||||
|         return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def check_python(minimum: str = '3.7.0') -> bool: | ||||
|     """ | ||||
|     Check current python version against the required minimum version. | ||||
| @ -229,7 +214,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=() | ||||
|     if s and install and AUTOINSTALL:  # check environment variable | ||||
|         LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...") | ||||
|         try: | ||||
|             assert check_online(), 'AutoUpdate skipped (offline)' | ||||
|             assert is_online(), 'AutoUpdate skipped (offline)' | ||||
|             LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode()) | ||||
|             s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \ | ||||
|                 f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" | ||||
| @ -249,13 +234,13 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''): | ||||
|                 assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}' | ||||
|  | ||||
|  | ||||
| def check_yolov5u_filename(file: str): | ||||
| def check_yolov5u_filename(file: str, verbose: bool = True): | ||||
|     # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames | ||||
|     if 'yolov3' in file or 'yolov5' in file and 'u' not in file: | ||||
|         original_file = file | ||||
|         file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file)  # i.e. yolov5n.pt -> yolov5nu.pt | ||||
|         file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt | ||||
|         if file != original_file: | ||||
|         if file != original_file and verbose: | ||||
|             LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " | ||||
|                         f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs ' | ||||
|                         f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n') | ||||
|  | ||||
| @ -12,7 +12,7 @@ import requests | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from ultralytics.yolo.utils import LOGGER, checks | ||||
| from ultralytics.yolo.utils import LOGGER, checks, is_online | ||||
|  | ||||
| GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \ | ||||
|                      [f'yolov5{size}u.pt' for size in 'nsmlx'] + \ | ||||
| @ -112,7 +112,7 @@ def safe_download(url, | ||||
|                         break  # success | ||||
|                     f.unlink()  # remove partial downloads | ||||
|             except Exception as e: | ||||
|                 if i == 0 and not checks.check_online(): | ||||
|                 if i == 0 and not is_online(): | ||||
|                     raise ConnectionError(f'❌  Download failure for {url}. Environment is not online.') from e | ||||
|                 elif i >= retry: | ||||
|                     raise ConnectionError(f'❌  Download failure for {url}. Retry limit reached.') from e | ||||
| @ -134,8 +134,7 @@ def safe_download(url, | ||||
|  | ||||
| def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'): | ||||
|     # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. | ||||
|     from ultralytics.yolo.utils import SETTINGS | ||||
|     from ultralytics.yolo.utils.checks import check_yolov5u_filename | ||||
|     from ultralytics.yolo.utils import SETTINGS  # scoped for circular import | ||||
|  | ||||
|     def github_assets(repository, version='latest'): | ||||
|         # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...]) | ||||
| @ -146,7 +145,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'): | ||||
|  | ||||
|     # YOLOv3/5u updates | ||||
|     file = str(file) | ||||
|     file = check_yolov5u_filename(file) | ||||
|     file = checks.check_yolov5u_filename(file) | ||||
|     file = Path(file.strip().replace("'", '')) | ||||
|     if file.exists(): | ||||
|         return str(file) | ||||
|  | ||||
| @ -43,16 +43,18 @@ def bbox_ioa(box1, box2, eps=1e-7): | ||||
|  | ||||
|  | ||||
| def box_iou(box1, box2, eps=1e-7): | ||||
|     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py | ||||
|     """ | ||||
|     Return intersection-over-union (Jaccard index) of boxes. | ||||
|     Both sets of boxes are expected to be in (x1, y1, x2, y2) format. | ||||
|     Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py | ||||
|  | ||||
|     Arguments: | ||||
|         box1 (Tensor[N, 4]) | ||||
|         box2 (Tensor[M, 4]) | ||||
|         eps | ||||
|  | ||||
|     Returns: | ||||
|         iou (Tensor[N, M]): the NxM matrix containing the pairwise | ||||
|             IoU values for every element in boxes1 and boxes2 | ||||
|         iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 | ||||
|     """ | ||||
|  | ||||
|     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) | ||||
| @ -109,7 +111,7 @@ def mask_iou(mask1, mask2, eps=1e-7): | ||||
|     mask1: [N, n] m1 means number of predicted objects | ||||
|     mask2: [M, n] m2 means number of gt objects | ||||
|     Note: n means image_w x image_h | ||||
|     return: masks iou, [N, M] | ||||
|     Returns: masks iou, [N, M] | ||||
|     """ | ||||
|     intersection = torch.matmul(mask1, mask2.t()).clamp(0) | ||||
|     union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection | ||||
| @ -121,7 +123,7 @@ def masks_iou(mask1, mask2, eps=1e-7): | ||||
|     mask1: [N, n] m1 means number of predicted objects | ||||
|     mask2: [N, n] m2 means number of gt objects | ||||
|     Note: n means image_w x image_h | ||||
|     return: masks iou, (N, ) | ||||
|     Returns: masks iou, (N, ) | ||||
|     """ | ||||
|     intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, ) | ||||
|     union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection | ||||
| @ -317,10 +319,10 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi | ||||
|  | ||||
| def compute_ap(recall, precision): | ||||
|     """ Compute the average precision, given the recall and precision curves | ||||
|     # Arguments | ||||
|     Arguments: | ||||
|         recall:    The recall curve (list) | ||||
|         precision: The precision curve (list) | ||||
|     # Returns | ||||
|     Returns: | ||||
|         Average precision, precision curve, recall curve | ||||
|     """ | ||||
|  | ||||
| @ -344,17 +346,30 @@ def compute_ap(recall, precision): | ||||
|  | ||||
|  | ||||
| def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''): | ||||
|     """ Compute the average precision, given the recall and precision curves. | ||||
|     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | ||||
|     # Arguments | ||||
|         tp:  True positives (nparray, nx1 or nx10). | ||||
|         conf:  Objectness value from 0-1 (nparray). | ||||
|         pred_cls:  Predicted object classes (nparray). | ||||
|         target_cls:  True object classes (nparray). | ||||
|         plot:  Plot precision-recall curve at mAP@0.5 | ||||
|         save_dir:  Plot save directory | ||||
|     # Returns | ||||
|         The average precision as computed in py-faster-rcnn. | ||||
|     """ | ||||
|     Computes the average precision per class for object detection evaluation. | ||||
|  | ||||
|     Args: | ||||
|         tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False). | ||||
|         conf (np.ndarray): Array of confidence scores of the detections. | ||||
|         pred_cls (np.ndarray): Array of predicted classes of the detections. | ||||
|         target_cls (np.ndarray): Array of true classes of the detections. | ||||
|         plot (bool, optional): Whether to plot PR curves or not. Defaults to False. | ||||
|         save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path. | ||||
|         names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple. | ||||
|         eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16. | ||||
|         prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string. | ||||
|  | ||||
|     Returns: | ||||
|         (tuple): A tuple of six arrays and one array of unique classes, where: | ||||
|             tp (np.ndarray): True positive counts for each class. | ||||
|             fp (np.ndarray): False positive counts for each class. | ||||
|             p (np.ndarray): Precision values at each confidence threshold. | ||||
|             r (np.ndarray): Recall values at each confidence threshold. | ||||
|             f1 (np.ndarray): F1-score values at each confidence threshold. | ||||
|             ap (np.ndarray): Average precision for each class at different IoU thresholds. | ||||
|             unique_classes (np.ndarray): An array of unique classes that have data. | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     # Sort by objectness | ||||
| @ -411,6 +426,32 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na | ||||
|  | ||||
|  | ||||
| class Metric: | ||||
|     """ | ||||
|         Class for computing evaluation metrics for YOLOv8 model. | ||||
|  | ||||
|         Attributes: | ||||
|             p (list): Precision for each class. Shape: (nc,). | ||||
|             r (list): Recall for each class. Shape: (nc,). | ||||
|             f1 (list): F1 score for each class. Shape: (nc,). | ||||
|             all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). | ||||
|             ap_class_index (list): Index of class for each AP score. Shape: (nc,). | ||||
|             nc (int): Number of classes. | ||||
|  | ||||
|         Methods: | ||||
|             ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or []. | ||||
|             ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or []. | ||||
|             mp(): Mean precision of all classes. Returns: Float. | ||||
|             mr(): Mean recall of all classes. Returns: Float. | ||||
|             map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float. | ||||
|             map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float. | ||||
|             map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float. | ||||
|             mean_results(): Mean of results, returns mp, mr, map50, map. | ||||
|             class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i]. | ||||
|             maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,). | ||||
|             fitness(): Model fitness as a weighted combination of metrics. Returns: Float. | ||||
|             update(results): Update metric attributes with new evaluation results. | ||||
|  | ||||
|         """ | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.p = []  # (nc, ) | ||||
| @ -420,10 +461,14 @@ class Metric: | ||||
|         self.ap_class_index = []  # (nc, ) | ||||
|         self.nc = 0 | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         name = self.__class__.__name__ | ||||
|         raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") | ||||
|  | ||||
|     @property | ||||
|     def ap50(self): | ||||
|         """AP@0.5 of all classes. | ||||
|         Return: | ||||
|         Returns: | ||||
|             (nc, ) or []. | ||||
|         """ | ||||
|         return self.all_ap[:, 0] if len(self.all_ap) else [] | ||||
| @ -431,7 +476,7 @@ class Metric: | ||||
|     @property | ||||
|     def ap(self): | ||||
|         """AP@0.5:0.95 | ||||
|         Return: | ||||
|         Returns: | ||||
|             (nc, ) or []. | ||||
|         """ | ||||
|         return self.all_ap.mean(1) if len(self.all_ap) else [] | ||||
| @ -439,7 +484,7 @@ class Metric: | ||||
|     @property | ||||
|     def mp(self): | ||||
|         """mean precision of all classes. | ||||
|         Return: | ||||
|         Returns: | ||||
|             float. | ||||
|         """ | ||||
|         return self.p.mean() if len(self.p) else 0.0 | ||||
| @ -447,7 +492,7 @@ class Metric: | ||||
|     @property | ||||
|     def mr(self): | ||||
|         """mean recall of all classes. | ||||
|         Return: | ||||
|         Returns: | ||||
|             float. | ||||
|         """ | ||||
|         return self.r.mean() if len(self.r) else 0.0 | ||||
| @ -455,7 +500,7 @@ class Metric: | ||||
|     @property | ||||
|     def map50(self): | ||||
|         """Mean AP@0.5 of all classes. | ||||
|         Return: | ||||
|         Returns: | ||||
|             float. | ||||
|         """ | ||||
|         return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 | ||||
| @ -463,7 +508,7 @@ class Metric: | ||||
|     @property | ||||
|     def map75(self): | ||||
|         """Mean AP@0.75 of all classes. | ||||
|         Return: | ||||
|         Returns: | ||||
|             float. | ||||
|         """ | ||||
|         return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0 | ||||
| @ -471,7 +516,7 @@ class Metric: | ||||
|     @property | ||||
|     def map(self): | ||||
|         """Mean AP@0.5:0.95 of all classes. | ||||
|         Return: | ||||
|         Returns: | ||||
|             float. | ||||
|         """ | ||||
|         return self.all_ap.mean() if len(self.all_ap) else 0.0 | ||||
| @ -506,6 +551,32 @@ class Metric: | ||||
|  | ||||
|  | ||||
| class DetMetrics: | ||||
|     """ | ||||
|     This class is a utility class for computing detection metrics such as precision, recall, and mean average precision | ||||
|     (mAP) of an object detection model. | ||||
|  | ||||
|     Args: | ||||
|         save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory. | ||||
|         plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False. | ||||
|         names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple. | ||||
|  | ||||
|     Attributes: | ||||
|         save_dir (Path): A path to the directory where the output plots will be saved. | ||||
|         plot (bool): A flag that indicates whether to plot the precision-recall curves for each class. | ||||
|         names (tuple of str): A tuple of strings that represents the names of the classes. | ||||
|         box (Metric): An instance of the Metric class for storing the results of the detection metrics. | ||||
|         speed (dict): A dictionary for storing the execution time of different parts of the detection process. | ||||
|  | ||||
|     Methods: | ||||
|         process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions. | ||||
|         keys: Returns a list of keys for accessing the computed detection metrics. | ||||
|         mean_results: Returns a list of mean values for the computed detection metrics. | ||||
|         class_result(i): Returns a list of values for the computed detection metrics for a specific class. | ||||
|         maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds. | ||||
|         fitness: Computes the fitness score based on the computed detection metrics. | ||||
|         ap_class_index: Returns a list of class indices sorted by their average precision (AP) values. | ||||
|         results_dict: Returns a dictionary that maps detection metric keys to their computed values. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None: | ||||
|         self.save_dir = save_dir | ||||
| @ -514,6 +585,10 @@ class DetMetrics: | ||||
|         self.box = Metric() | ||||
|         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         name = self.__class__.__name__ | ||||
|         raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") | ||||
|  | ||||
|     def process(self, tp, conf, pred_cls, target_cls): | ||||
|         results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir, | ||||
|                                names=self.names)[2:] | ||||
| @ -548,6 +623,31 @@ class DetMetrics: | ||||
|  | ||||
|  | ||||
| class SegmentMetrics: | ||||
|     """ | ||||
|     Calculates and aggregates detection and segmentation metrics over a given set of classes. | ||||
|  | ||||
|     Args: | ||||
|         save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory. | ||||
|         plot (bool): Whether to save the detection and segmentation plots. Default is False. | ||||
|         names (list): List of class names. Default is an empty list. | ||||
|  | ||||
|     Attributes: | ||||
|         save_dir (Path): Path to the directory where the output plots should be saved. | ||||
|         plot (bool): Whether to save the detection and segmentation plots. | ||||
|         names (list): List of class names. | ||||
|         box (Metric): An instance of the Metric class to calculate box detection metrics. | ||||
|         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics. | ||||
|         speed (dict): Dictionary to store the time taken in different phases of inference. | ||||
|  | ||||
|     Methods: | ||||
|         process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions. | ||||
|         mean_results(): Returns the mean of the detection and segmentation metrics over all the classes. | ||||
|         class_result(i): Returns the detection and segmentation metrics of class `i`. | ||||
|         maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95. | ||||
|         fitness: Returns the fitness scores, which are a single weighted combination of metrics. | ||||
|         ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP). | ||||
|         results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None: | ||||
|         self.save_dir = save_dir | ||||
| @ -557,7 +657,22 @@ class SegmentMetrics: | ||||
|         self.seg = Metric() | ||||
|         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         name = self.__class__.__name__ | ||||
|         raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") | ||||
|  | ||||
|     def process(self, tp_m, tp_b, conf, pred_cls, target_cls): | ||||
|         """ | ||||
|         Processes the detection and segmentation metrics over the given set of predictions. | ||||
|  | ||||
|         Args: | ||||
|             tp_m (list): List of True Positive masks. | ||||
|             tp_b (list): List of True Positive boxes. | ||||
|             conf (list): List of confidence scores. | ||||
|             pred_cls (list): List of predicted classes. | ||||
|             target_cls (list): List of target classes. | ||||
|         """ | ||||
|  | ||||
|         results_mask = ap_per_class(tp_m, | ||||
|                                     conf, | ||||
|                                     pred_cls, | ||||
| @ -610,12 +725,32 @@ class SegmentMetrics: | ||||
|  | ||||
|  | ||||
| class ClassifyMetrics: | ||||
|     """ | ||||
|     Class for computing classification metrics including top-1 and top-5 accuracy. | ||||
|  | ||||
|     Attributes: | ||||
|         top1 (float): The top-1 accuracy. | ||||
|         top5 (float): The top-5 accuracy. | ||||
|         speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline. | ||||
|  | ||||
|     Properties: | ||||
|         fitness (float): The fitness of the model, which is equal to top-5 accuracy. | ||||
|         results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness. | ||||
|         keys (List[str]): A list of keys for the results_dict. | ||||
|  | ||||
|     Methods: | ||||
|         process(targets, pred): Processes the targets and predictions to compute classification metrics. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.top1 = 0 | ||||
|         self.top5 = 0 | ||||
|         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         name = self.__class__.__name__ | ||||
|         raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") | ||||
|  | ||||
|     def process(self, targets, pred): | ||||
|         # target classes and predicted classes | ||||
|         pred, targets = torch.cat(pred), torch.cat(targets) | ||||
|  | ||||
| @ -301,14 +301,14 @@ def plot_images(images, | ||||
|  | ||||
|             # Plot masks | ||||
|             if len(masks): | ||||
|                 if masks.max() > 1.0:  # mean that masks are overlap | ||||
|                 if idx.shape[0] == masks.shape[0]:  # overlap_masks=False | ||||
|                     image_masks = masks[idx] | ||||
|                 else:  # overlap_masks=True | ||||
|                     image_masks = masks[[i]]  # (1, 640, 640) | ||||
|                     nl = idx.sum() | ||||
|                     index = np.arange(nl).reshape(nl, 1, 1) + 1 | ||||
|                     image_masks = np.repeat(image_masks, nl, axis=0) | ||||
|                     image_masks = np.where(image_masks == index, 1.0, 0.0) | ||||
|                 else: | ||||
|                     image_masks = masks[idx] | ||||
|  | ||||
|                 im = np.asarray(annotator.im).copy() | ||||
|                 for j, box in enumerate(boxes.T.tolist()): | ||||
|  | ||||
| @ -52,7 +52,8 @@ class ClassificationPredictor(BasePredictor): | ||||
|             return log_string | ||||
|         prob = result.probs | ||||
|         # Print results | ||||
|         top5i = prob.argsort(0, descending=True)[:5].tolist()  # top 5 indices | ||||
|         n5 = min(len(self.model.names), 5) | ||||
|         top5i = prob.argsort(0, descending=True)[:n5].tolist()  # top 5 indices | ||||
|         log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, " | ||||
|  | ||||
|         # write | ||||
|  | ||||
| @ -27,7 +27,8 @@ class ClassificationValidator(BaseValidator): | ||||
|         return batch | ||||
|  | ||||
|     def update_metrics(self, preds, batch): | ||||
|         self.pred.append(preds.argsort(1, descending=True)[:, :5]) | ||||
|         n5 = min(len(self.model.names), 5) | ||||
|         self.pred.append(preds.argsort(1, descending=True)[:, :n5]) | ||||
|         self.targets.append(batch['cls']) | ||||
|  | ||||
|     def finalize_metrics(self, *args, **kwargs): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user