Avoid CUDA round-trip for relevant export formats (#3727)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-07-14 20:38:31 +02:00
committed by GitHub
parent c5991d7cd8
commit 135a10f1fa
5 changed files with 40 additions and 32 deletions

View File

@ -213,7 +213,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
file = None
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f'{prefix} {file} not found, check failed.'
@ -225,13 +224,13 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
s = '' # console string
pkgs = []
for r in requirements:
rmin = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
try:
pkg.require(rmin)
pkg.require(r_stripped)
except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
try: # attempt to import (slower but more accurate)
import importlib
importlib.import_module(next(pkg.parse_requirements(rmin)).name)
importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
except ImportError:
s += f'"{r}" '
pkgs.append(r)