ultralytics 8.0.80
single-line docstring fixes (#2060)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -46,7 +46,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
# classification models require special handling
|
||||
# Classification models require special handling
|
||||
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
@ -22,8 +22,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
class DetectionTrainer(BaseTrainer):
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
"""TODO: manage splits differently."""
|
||||
# Calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return create_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
@ -48,7 +48,7 @@ class DetectionTrainer(BaseTrainer):
|
||||
return batch
|
||||
|
||||
def set_model_attributes(self):
|
||||
# nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
|
||||
"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
|
||||
# self.args.box *= 3 / nl # scale to layers
|
||||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
|
@ -67,7 +67,7 @@ class DetectionValidator(BaseValidator):
|
||||
return preds
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
# Metrics
|
||||
"""Metrics."""
|
||||
for si, pred in enumerate(preds):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
@ -164,8 +164,8 @@ class DetectionValidator(BaseValidator):
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
"""TODO: manage splits differently."""
|
||||
# Calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
|
||||
return create_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
|
@ -47,7 +47,7 @@ class PoseValidator(DetectionValidator):
|
||||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
# Metrics
|
||||
"""Metrics."""
|
||||
for si, pred in enumerate(preds):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
|
@ -10,7 +10,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
class SegmentationPredictor(DetectionPredictor):
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
# TODO: filter by classes
|
||||
"""TODO: filter by classes."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
|
@ -140,7 +140,7 @@ class SegLoss(Loss):
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
|
||||
# Mask loss for one image
|
||||
"""Mask loss for one image."""
|
||||
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
|
||||
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
|
||||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||
|
@ -52,7 +52,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
return p, proto
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
# Metrics
|
||||
"""Metrics."""
|
||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
@ -179,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
self.plot_masks.clear()
|
||||
|
||||
def pred_to_json(self, predn, filename, pred_masks):
|
||||
# Save one JSON result
|
||||
"""Save one JSON result."""
|
||||
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
|
||||
from pycocotools.mask import encode # noqa
|
||||
|
||||
|
Reference in New Issue
Block a user