standalone val (#56)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-11-30 15:04:44 +05:30
committed by GitHub
parent 3a241e4cea
commit 5a52e7663a
16 changed files with 161 additions and 31 deletions

View File

@ -113,8 +113,8 @@ def get_model(model='s.pt', pretrained=True):
model = model.split(".")[0]
if Path(f"{model}.pt").is_file(): # local file
return torch.load(f"{model}.pt", map_location='cpu')
return attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: # Ultralytics assets
return torch.load(attempt_download(f"{model}.pt"), map_location='cpu')
return attempt_load_weights(f"{model}.pt", device='cpu')

View File

@ -304,7 +304,7 @@ class AutoBackend(nn.Module):
def _model_type(p='path/to/model.pt'):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from export import export_formats
from ultralytics.yolo.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False):
check_suffix(p, sf) # checks

View File

@ -172,7 +172,7 @@ class DetectionModel(BaseModel):
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}')
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
class SegmentationModel(DetectionModel):