ultralytics 8.0.78
Docker and confusion matrix updates (#2035)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Abdul Manaf <75582860+AbdulManaf12@users.noreply.github.com>
This commit is contained in:
@ -174,9 +174,9 @@ class BaseValidator:
|
||||
self.run_callbacks('on_val_batch_end')
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
self.print_results()
|
||||
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
|
||||
self.finalize_metrics()
|
||||
self.print_results()
|
||||
self.run_callbacks('on_val_end')
|
||||
if self.training:
|
||||
model.float()
|
||||
|
@ -220,13 +220,24 @@ def set_logging(name=LOGGING_NAME, verbose=True):
|
||||
'propagate': False}}})
|
||||
|
||||
|
||||
class EmojiFilter(logging.Filter):
|
||||
"""
|
||||
A custom logging filter class for removing emojis in log messages.
|
||||
|
||||
This filter is particularly useful for ensuring compatibility with Windows terminals
|
||||
that may not support the display of emojis in log messages.
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
record.msg = emojis(record.msg)
|
||||
return super().filter(record)
|
||||
|
||||
|
||||
# Set logger
|
||||
set_logging(LOGGING_NAME, verbose=VERBOSE) # run before defining LOGGER
|
||||
LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
|
||||
if WINDOWS: # emoji-safe logging
|
||||
info_fn, warning_fn = LOGGER.info, LOGGER.warning
|
||||
setattr(LOGGER, info_fn.__name__, lambda x: info_fn(emojis(x)))
|
||||
setattr(LOGGER, warning_fn.__name__, lambda x: warning_fn(emojis(x)))
|
||||
LOGGER.addFilter(EmojiFilter())
|
||||
|
||||
|
||||
def yaml_save(file='data.yaml', data=None):
|
||||
|
@ -197,7 +197,19 @@ def check_python(minimum: str = '3.7.0') -> bool:
|
||||
|
||||
@TryExcept()
|
||||
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
|
||||
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str)
|
||||
"""
|
||||
Check if installed dependencies meet YOLOv5 requirements and attempt to auto-update if needed.
|
||||
|
||||
Args:
|
||||
requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
|
||||
string, or a list of package requirements as strings.
|
||||
exclude (Tuple[str]): Tuple of package names to exclude from checking.
|
||||
install (bool): If True, attempt to auto-update packages that don't meet requirements.
|
||||
cmds (str): Additional commands to pass to the pip install command when auto-updating.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
check_python() # check python version
|
||||
file = None
|
||||
@ -209,8 +221,8 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
elif isinstance(requirements, str):
|
||||
requirements = [requirements]
|
||||
|
||||
s = ''
|
||||
n = 0
|
||||
s = '' # console string
|
||||
n = 0 # number of packages updates
|
||||
for r in requirements:
|
||||
try:
|
||||
pkg.require(r)
|
||||
@ -226,7 +238,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
||||
try:
|
||||
assert is_online(), 'AutoUpdate skipped (offline)'
|
||||
LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
|
||||
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
|
||||
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
||||
LOGGER.info(s)
|
||||
|
@ -12,7 +12,6 @@ class ClassificationValidator(BaseValidator):
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'classify'
|
||||
self.metrics = ClassifyMetrics()
|
||||
self.save_dir = save_dir
|
||||
|
||||
def get_desc(self):
|
||||
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
|
||||
@ -37,6 +36,8 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
@ -55,8 +56,6 @@ class ClassificationValidator(BaseValidator):
|
||||
def print_results(self):
|
||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
|
Reference in New Issue
Block a user