Fix Classification train logging (#157)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-08 17:52:44 +01:00
committed by GitHub
parent d387359f74
commit e79ea1666c
7 changed files with 86 additions and 40 deletions

View File

@ -259,6 +259,7 @@ class ClassificationModel(BaseModel):
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
self.info()
def load(self, weights):
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
@ -292,7 +293,6 @@ class ClassificationModel(BaseModel):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
LOGGER.info("WARNING: Deprecated in favor of attempt_load_one_weight()")
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download