ultralytics 8.0.80 single-line docstring fixes (#2060)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Glenn Jocher
Date: 2023-04-16 15:20:11 +02:00
Committed by: GitHub
Parent: 31db8ed163
Commit: 5bce1c3021
48 changed files with 418 additions and 420 deletions
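The pattern applied across all 48 files is the same: a method's leading # comment is promoted to a PEP 257 single-line docstring and the remaining comments are capitalized. A minimal before/after sketch of the conversion (hypothetical build_loader helper, not code from the repository):

# Before: leading comments, no docstring
def build_loader(path, batch_size):
    # TODO: manage splits differently
    # calculate stride - check if model is initialized
    ...

# After: the first comment becomes a single-line docstring, the next is capitalized
def build_loader(path, batch_size):
    """TODO: manage splits differently."""
    # Calculate stride - check if model is initialized
    ...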

View File

@@ -46,7 +46,7 @@ class ClassificationTrainer(BaseTrainer):
"""
load/create/download model for any task
"""
-# classification models require special handling
+# Classification models require special handling
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return

View File

@@ -22,8 +22,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class DetectionTrainer(BaseTrainer):
def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'):
-# TODO: manage splits differently
-# calculate stride - check if model is initialized
+"""TODO: manage splits differently."""
+# Calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return create_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
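For context, gs above is the model's maximum grid stride (at least 32), and the dataloader is expected to receive an image size that is a multiple of it. A minimal sketch of that rounding rule, assuming a standalone helper rather than the library's own size check:

import math

def round_to_stride(imgsz: int, stride: int = 32) -> int:
    # Round an image size up to the nearest multiple of the model stride,
    # e.g. 638 -> 640 for stride 32, so feature maps divide evenly at every level.
    return int(math.ceil(imgsz / stride) * stride)

print(round_to_stride(638))   # 640
print(round_to_stride(1024))  # 1024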
@@ -48,7 +48,7 @@ class DetectionTrainer(BaseTrainer):
return batch
def set_model_attributes(self):
-# nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
+"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
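The commented-out lines above keep the YOLOv5-style rule of thumb for rescaling loss gains that were tuned for 3 detection layers, 80 classes and 640-pixel images. A small worked example with assumed values (the lines are disabled in the actual trainer):

# Hypothetical configuration, for illustration only
nl, nc, imgsz = 3, 20, 1280

box_scale = 3 / nl                       # 1.0  -> box gain unchanged for 3 layers
cls_scale = nc / 80 * 3 / nl             # 0.25 -> fewer classes, smaller cls gain
img_scale = (imgsz / 640) ** 2 * 3 / nl  # 4.0  -> larger images, larger gain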

View File

@@ -67,7 +67,7 @@ class DetectionValidator(BaseValidator):
return preds
def update_metrics(self, preds, batch):
-# Metrics
+"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
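update_metrics relies on the collated-batch convention: targets from every image are stacked into flat tensors and batch['batch_idx'] records which image each row belongs to, so si selects one image's labels. A minimal sketch with made-up tensors (not validator code):

import torch

batch = {
    'batch_idx': torch.tensor([0, 0, 1]),     # image index per target row
    'cls': torch.tensor([[3.], [5.], [3.]]),  # class per target row
    'bboxes': torch.tensor([[0.5, 0.5, 0.2, 0.2],
                            [0.3, 0.3, 0.1, 0.1],
                            [0.7, 0.6, 0.4, 0.4]]),
}

si = 1                          # second image in the batch
idx = batch['batch_idx'] == si  # boolean mask over all target rows
cls, bbox = batch['cls'][idx], batch['bboxes'][idx]
print(cls.shape, bbox.shape)    # torch.Size([1, 1]) torch.Size([1, 4])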
@@ -164,8 +164,8 @@ class DetectionValidator(BaseValidator):
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def get_dataloader(self, dataset_path, batch_size):
-# TODO: manage splits differently
-# calculate stride - check if model is initialized
+"""TODO: manage splits differently."""
+# Calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
return create_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
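The correct tensor built just above marks, per prediction and per IoU threshold, whether it hits a ground-truth box of the same class. A simplified sketch of that matching idea (it uses torchvision's box_iou and skips the one-to-one assignment the real validator performs):

import torch
from torchvision.ops import box_iou

def match_predictions(pred_boxes, pred_cls, gt_boxes, gt_cls,
                      iouv=torch.linspace(0.5, 0.95, 10)):
    # correct[i, j] is True if prediction i overlaps a same-class GT at IoU >= iouv[j]
    iou = box_iou(gt_boxes, pred_boxes)              # (n_gt, n_pred) pairwise IoU
    same_cls = gt_cls[:, None] == pred_cls[None, :]  # class agreement mask
    iou = iou * same_cls                             # zero out cross-class overlaps
    best_iou = iou.max(dim=0).values                 # best same-class IoU per prediction
    return best_iou[:, None] >= iouv[None, :]        # (n_pred, n_thresholds) bool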

View File

@@ -47,7 +47,7 @@ class PoseValidator(DetectionValidator):
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def update_metrics(self, preds, batch):
-# Metrics
+"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
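The sigma set above feeds the object keypoint similarity (OKS) used to score pose predictions, falling back to a uniform 1/nkpt weighting when no COCO sigmas apply. A hedged sketch of the standard COCO-style OKS, not the validator's exact implementation:

import numpy as np

def keypoint_oks(pred, gt, vis, area, sigma):
    # pred, gt: (nkpt, 2) keypoint coordinates; vis: (nkpt,) visibility flags
    # sigma: (nkpt,) per-keypoint constants (OKS_SIGMA for the 17 COCO keypoints)
    d2 = ((pred - gt) ** 2).sum(axis=1)   # squared pixel distances
    k2 = (2 * sigma) ** 2                 # per-keypoint tolerance
    e = d2 / (2 * area * k2 + 1e-9)       # normalized error per keypoint
    mask = vis > 0
    return float(np.exp(-e)[mask].mean()) if mask.any() else 0.0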

View File

@@ -10,7 +10,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
class SegmentationPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_imgs):
-# TODO: filter by classes
+"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
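non_max_suppression keeps boxes scoring above self.args.conf and then discards boxes that overlap a higher-scoring survivor beyond self.args.iou. A toy sketch of the classic greedy loop with example thresholds, for illustration only (not ops.non_max_suppression, which also handles classes and masks):

import torch
from torchvision.ops import box_iou

def greedy_nms(boxes, scores, conf_thres=0.25, iou_thres=0.45):
    # boxes: (n, 4) xyxy, scores: (n,) confidences
    keep_conf = scores > conf_thres                 # confidence filter first
    boxes, scores = boxes[keep_conf], scores[keep_conf]
    order = scores.argsort(descending=True)
    keep = []
    while order.numel():
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        ious = box_iou(boxes[i].unsqueeze(0), boxes[order[1:]]).squeeze(0)
        order = order[1:][ious <= iou_thres]        # drop heavy overlaps with box i
    return boxes[keep], scores[keep]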

View File

@@ -140,7 +140,7 @@ class SegLoss(Loss):
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
-# Mask loss for one image
+"""Mask loss for one image."""
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
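crop_mask above zeroes the per-pixel BCE loss outside each ground-truth box before it is averaged, so only pixels inside the box contribute. A minimal sketch of that cropping via broadcasting, written independently of the library helper:

import torch

def crop_to_boxes(masks, boxes):
    # masks: (n, h, w) loss or mask maps; boxes: (n, 4) xyxy in mask coordinates
    n, h, w = masks.shape
    x1, y1, x2, y2 = boxes[:, :, None].chunk(4, dim=1)           # each (n, 1, 1)
    cols = torch.arange(w, device=masks.device)[None, None, :]   # (1, 1, w)
    rows = torch.arange(h, device=masks.device)[None, :, None]   # (1, h, 1)
    inside = (cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2)
    return masks * inside                                        # zero outside the box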

View File

@@ -52,7 +52,7 @@ class SegmentationValidator(DetectionValidator):
return p, proto
def update_metrics(self, preds, batch):
-# Metrics
+"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
@@ -179,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
-# Save one JSON result
+"""Save one JSON result."""
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode # noqa
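For reference, pycocotools' encode expects a Fortran-ordered uint8 array and returns run-length-encoded masks whose counts field is bytes, so it must be decoded before JSON serialization. A small standalone sketch with a toy mask and the ids from the example above:

import json
import numpy as np
from pycocotools.mask import encode  # noqa

mask = np.zeros((80, 80), dtype=np.uint8)
mask[20:60, 10:50] = 1                         # toy binary mask

rle = encode(np.asarray(mask[:, :, None], order='F', dtype=np.uint8))[0]
rle['counts'] = rle['counts'].decode('utf-8')  # bytes -> str so json.dumps works

result = {'image_id': 42, 'category_id': 18, 'segmentation': rle, 'score': 0.236}
print(json.dumps(result)[:80])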