Avoid CUDA round-trip for relevant export formats (#3727)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -213,7 +213,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
check_python() # check python version
|
||||
check_torchvision() # check torch-torchvision compatibility
|
||||
file = None
|
||||
if isinstance(requirements, Path): # requirements.txt file
|
||||
file = requirements.resolve()
|
||||
assert file.exists(), f'{prefix} {file} not found, check failed.'
|
||||
@ -225,13 +224,13 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
s = '' # console string
|
||||
pkgs = []
|
||||
for r in requirements:
|
||||
rmin = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
|
||||
r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
|
||||
try:
|
||||
pkg.require(rmin)
|
||||
pkg.require(r_stripped)
|
||||
except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
|
||||
try: # attempt to import (slower but more accurate)
|
||||
import importlib
|
||||
importlib.import_module(next(pkg.parse_requirements(rmin)).name)
|
||||
importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
|
||||
except ImportError:
|
||||
s += f'"{r}" '
|
||||
pkgs.append(r)
|
||||
|
Reference in New Issue
Block a user