Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher
2023-02-17 22:26:40 +01:00
committed by GitHub
parent 9047d737f4
commit edd3ff1669
76 changed files with 928 additions and 935 deletions

View File

@ -21,14 +21,14 @@ class ClassificationValidator(BaseValidator):
self.targets = []
def preprocess(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
batch["cls"] = batch["cls"].to(self.device)
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
batch['cls'] = batch['cls'].to(self.device)
return batch
def update_metrics(self, preds, batch):
self.pred.append(preds.argsort(1, descending=True)[:, :5])
self.targets.append(batch["cls"])
self.targets.append(batch['cls'])
def get_stats(self):
self.metrics.process(self.targets, self.pred)
@ -42,12 +42,12 @@ class ClassificationValidator(BaseValidator):
def print_results(self):
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def val(cfg=DEFAULT_CFG, use_python=False):
model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
data = cfg.data or "mnist160"
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160'
args = dict(model=model, data=data)
if use_python:
@ -58,5 +58,5 @@ def val(cfg=DEFAULT_CFG, use_python=False):
validator(model=args['model'])
if __name__ == "__main__":
if __name__ == '__main__':
val()