new check_dataset functions (#43)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -7,7 +7,7 @@ from pathlib import Path
|
||||
|
||||
# Constants
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[2] # YOLOv5 root directory
|
||||
ROOT = FILE.parents[2] # YOLO
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
|
||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||
|
@ -116,13 +116,10 @@ def check_file(file, suffix=''):
|
||||
torch.hub.download_url_to_file(url, file)
|
||||
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
|
||||
return file
|
||||
elif file.startswith('clearml://'): # ClearML Dataset ID
|
||||
assert 'clearml' in sys.modules, "Can not use ClearML dataset. Run 'pip install clearml' to install"
|
||||
return file
|
||||
else: # search
|
||||
files = []
|
||||
for d in 'data', 'models', 'utils': # search directories
|
||||
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
||||
for d in 'data', 'v8', 'utils': # search directories
|
||||
files.extend(glob.glob(str(ROOT / "yolo" / d / '**' / file), recursive=True)) # find file
|
||||
assert len(files), f'File not found: {file}' # assert file was found
|
||||
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
|
||||
return files[0] # return file
|
||||
|
@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import yaml
|
||||
|
||||
@ -44,3 +45,19 @@ def save_yaml(file='data.yaml', data=None):
|
||||
# Single-line safe yaml saving
|
||||
with open(file, 'w') as f:
|
||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||
|
||||
|
||||
def yaml_load(file='data.yaml'):
|
||||
# Single-line safe yaml loading
|
||||
with open(file, errors='ignore') as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
|
||||
# Unzip a *.zip file to path/, excluding files containing strings in exclude list
|
||||
if path is None:
|
||||
path = Path(file).parent # default path
|
||||
with ZipFile(file) as zipObj:
|
||||
for f in zipObj.namelist(): # list all archived filenames in the zip
|
||||
if all(x not in f for x in exclude):
|
||||
zipObj.extract(f, path=path)
|
||||
|
@ -118,9 +118,3 @@ def get_model(model='s.pt', pretrained=True):
|
||||
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||
else: # Ultralytics assets
|
||||
return torch.load(attempt_download(f"{model}.pt"), map_location='cpu')
|
||||
|
||||
|
||||
def yaml_load(file='data.yaml'):
|
||||
# Single-line safe yaml loading
|
||||
with open(file, errors='ignore') as f:
|
||||
return yaml.safe_load(f)
|
||||
|
@ -32,7 +32,8 @@ class AutoBackend(nn.Module):
|
||||
# TensorFlow Lite: *.tflite
|
||||
# TensorFlow Edge TPU: *_edgetpu.tflite
|
||||
# PaddlePaddle: *_paddle_model
|
||||
from ultralytics.yolo.utils.modeling import attempt_load_weights, yaml_load
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.modeling import attempt_load_weights
|
||||
|
||||
super().__init__()
|
||||
w = str(weights[0] if isinstance(weights, list) else weights)
|
||||
@ -315,7 +316,7 @@ class AutoBackend(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _load_metadata(f=Path('path/to/meta.yaml')):
|
||||
from ultralytics.yolo.utils.modeling import yaml_load
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
|
||||
# Load metadata from meta.yaml if it exists
|
||||
if f.exists():
|
||||
|
Reference in New Issue
Block a user