Metrics and loss structure (#28)

Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2022-10-15 23:09:05 +05:30
committed by GitHub
parent d0b3c9812b
commit c5cb76b356
12 changed files with 183 additions and 43 deletions

View File

@ -1,4 +1,4 @@
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
from .general import Profile, WorkingDirectory, check_version, download, increment_path, save_yaml
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
__all__ = [
@ -8,6 +8,7 @@ __all__ = [
"WorkingDirectory",
"download",
"check_version",
"Profile",
# torch
"torch_distributed_zero_first",
"LOCAL_RANK",

View File

@ -1,3 +1,5 @@
model: null
data: null
train:
epochs: 300
batch_size: 16

View File

@ -5,6 +5,7 @@ import logging
import os
import platform
import subprocess
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
@ -208,7 +209,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
return path
def save_yaml(file='data.yaml', data={}):
def save_yaml(file='data.yaml', data=None):
# Single-line safe yaml saving
with open(file, 'w') as f:
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
@ -278,7 +279,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
from utils.general import LOGGER
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
@ -301,7 +301,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from utils.general import LOGGER
def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
@ -351,3 +350,23 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
def get_model(model: str):
# check for local weights
pass
class Profile(contextlib.ContextDecorator):
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
def __init__(self, t=0.0):
self.t = t
self.cuda = torch.cuda.is_available()
def __enter__(self):
self.start = self.time()
return self
def __exit__(self, type, value, traceback):
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def time(self):
if self.cuda:
torch.cuda.synchronize()
return time.time()

View File

@ -2,6 +2,7 @@
"""
Model validation metrics
"""
import numpy as np