Metrics and loss structure (#28)
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
|
||||
from .general import Profile, WorkingDirectory, check_version, download, increment_path, save_yaml
|
||||
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
|
||||
|
||||
__all__ = [
|
||||
@ -8,6 +8,7 @@ __all__ = [
|
||||
"WorkingDirectory",
|
||||
"download",
|
||||
"check_version",
|
||||
"Profile",
|
||||
# torch
|
||||
"torch_distributed_zero_first",
|
||||
"LOCAL_RANK",
|
||||
|
@ -1,3 +1,5 @@
|
||||
model: null
|
||||
data: null
|
||||
train:
|
||||
epochs: 300
|
||||
batch_size: 16
|
||||
|
@ -5,6 +5,7 @@ import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
import urllib
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import ThreadPool
|
||||
@ -208,7 +209,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
return path
|
||||
|
||||
|
||||
def save_yaml(file='data.yaml', data={}):
|
||||
def save_yaml(file='data.yaml', data=None):
|
||||
# Single-line safe yaml saving
|
||||
with open(file, 'w') as f:
|
||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||
@ -278,7 +279,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
||||
|
||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
||||
from utils.general import LOGGER
|
||||
|
||||
file = Path(file)
|
||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
||||
@ -301,7 +301,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||
|
||||
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||
from utils.general import LOGGER
|
||||
|
||||
def github_assets(repository, version='latest'):
|
||||
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
||||
@ -351,3 +350,23 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
||||
def get_model(model: str):
|
||||
# check for local weights
|
||||
pass
|
||||
|
||||
|
||||
class Profile(contextlib.ContextDecorator):
|
||||
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
||||
def __init__(self, t=0.0):
|
||||
self.t = t
|
||||
self.cuda = torch.cuda.is_available()
|
||||
|
||||
def __enter__(self):
|
||||
self.start = self.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.dt = self.time() - self.start # delta-time
|
||||
self.t += self.dt # accumulate dt
|
||||
|
||||
def time(self):
|
||||
if self.cuda:
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
@ -2,6 +2,7 @@
|
||||
"""
|
||||
Model validation metrics
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user