ultralytics 8.0.42
DDP fix and Docs updates (#1065)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
@ -1136,11 +1136,11 @@ class HUBDatasetStats():
|
||||
# Save, print and return
|
||||
if save:
|
||||
stats_path = self.hub_dir / 'stats.json'
|
||||
print(f'Saving {stats_path.resolve()}...')
|
||||
LOGGER.info(f'Saving {stats_path.resolve()}...')
|
||||
with open(stats_path, 'w') as f:
|
||||
json.dump(self.stats, f) # save stats.json
|
||||
if verbose:
|
||||
print(json.dumps(self.stats, indent=2, sort_keys=False))
|
||||
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
|
||||
return self.stats
|
||||
|
||||
def process_images(self):
|
||||
@ -1154,7 +1154,7 @@ class HUBDatasetStats():
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
|
||||
pass
|
||||
print(f'Done. All images saved to {self.im_dir}')
|
||||
LOGGER.info(f'Done. All images saved to {self.im_dir}')
|
||||
return self.im_dir
|
||||
|
||||
|
||||
|
@ -75,7 +75,6 @@ from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
|
||||
CUDA = torch.cuda.is_available()
|
||||
ARM64 = platform.machine() in ('arm64', 'aarch64')
|
||||
|
||||
|
||||
@ -324,7 +323,7 @@ class Exporter:
|
||||
# Simplify
|
||||
if self.args.simplify:
|
||||
try:
|
||||
check_requirements(('onnxsim', 'onnxruntime-gpu' if CUDA else 'onnxruntime'))
|
||||
check_requirements(('onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
||||
@ -506,10 +505,12 @@ class Exporter:
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if CUDA else '-cpu'}")
|
||||
check_requirements(
|
||||
f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if torch.cuda.is_available() else '-cpu'}"
|
||||
)
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
|
||||
'onnxruntime-gpu' if CUDA else 'onnxruntime'),
|
||||
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
||||
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
|
@ -32,7 +32,7 @@ class YOLO:
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
|
||||
Args:
|
||||
model (str or Path): Path to the model file to load or create.
|
||||
model (str, Path): Path to the model file to load or create.
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
|
||||
Attributes:
|
||||
@ -62,7 +62,7 @@ class YOLO:
|
||||
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
List[ultralytics.yolo.engine.results.Results]: The prediction results.
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', type='v8') -> None:
|
||||
@ -114,6 +114,7 @@ class YOLO:
|
||||
self.task = guess_model_task(cfg_dict)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1) # initialize
|
||||
self.overrides['model'] = self.cfg
|
||||
|
||||
def _load(self, weights: str):
|
||||
"""
|
||||
@ -204,7 +205,7 @@ class YOLO:
|
||||
def track(self, source=None, stream=False, **kwargs):
|
||||
from ultralytics.tracker.track import register_tracker
|
||||
register_tracker(self)
|
||||
# bytetrack-based method needs low confidence predictions as input
|
||||
# ByteTrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get('conf') or 0.1
|
||||
kwargs['conf'] = conf
|
||||
kwargs['mode'] = 'track'
|
||||
|
@ -92,6 +92,7 @@ class BasePredictor:
|
||||
self.annotator = None
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.batch = None
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
|
@ -28,13 +28,14 @@ class Results:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, boxes=None, masks=None, probs=None, orig_img=None, names=None) -> None:
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs if probs is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self.comp = ['boxes', 'masks', 'probs']
|
||||
|
||||
def pandas(self):
|
||||
@ -42,7 +43,7 @@ class Results:
|
||||
# TODO masks.pandas + boxes.pandas + cls.pandas
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r = Results(orig_img=self.orig_img)
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -58,7 +59,7 @@ class Results:
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
r = Results(orig_img=self.orig_img)
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -66,7 +67,7 @@ class Results:
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
r = Results(orig_img=self.orig_img)
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -74,7 +75,7 @@ class Results:
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
r = Results(orig_img=self.orig_img)
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -82,7 +83,7 @@ class Results:
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r = Results(orig_img=self.orig_img)
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -123,7 +124,7 @@ class Results:
|
||||
orig_shape (tuple, optional): Original image size.
|
||||
""")
|
||||
|
||||
def visualize(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""
|
||||
Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
|
||||
|
||||
@ -146,9 +147,9 @@ class Results:
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if masks is not None:
|
||||
im_gpu = torch.as_tensor(img, dtype=torch.float16).permute(2, 0, 1).flip(0).contiguous()
|
||||
im_gpu = F.resize(im_gpu, masks.data.shape[1:]) / 255
|
||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im_gpu)
|
||||
im = torch.as_tensor(img, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
|
||||
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
|
||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
|
||||
|
||||
if logits is not None:
|
||||
top5i = logits.argsort(0, descending=True)[:5].tolist() # top 5 indices
|
||||
@ -371,24 +372,3 @@ class Masks:
|
||||
Properties:
|
||||
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test examples
|
||||
results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
|
||||
results = results.cuda()
|
||||
print('--cuda--pass--')
|
||||
results = results.cpu()
|
||||
print('--cpu--pass--')
|
||||
results = results.to('cuda:0')
|
||||
print('--to-cuda--pass--')
|
||||
results = results.to('cpu')
|
||||
print('--to-cpu--pass--')
|
||||
results = results.numpy()
|
||||
print('--numpy--pass--')
|
||||
# box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
|
||||
# box = box.cuda()
|
||||
# box = box.cpu()
|
||||
# box = box.numpy()
|
||||
# for b in box:
|
||||
# print(b)
|
||||
|
@ -11,7 +11,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.yolo.utils import TryExcept
|
||||
from ultralytics.yolo.utils import LOGGER, TryExcept
|
||||
|
||||
|
||||
# boxes
|
||||
@ -260,7 +260,7 @@ class ConfusionMatrix:
|
||||
|
||||
def print(self):
|
||||
for i in range(self.nc + 1):
|
||||
print(' '.join(map(str, self.matrix[i])))
|
||||
LOGGER.info(' '.join(map(str, self.matrix[i])))
|
||||
|
||||
|
||||
def smooth(y, f=0.05):
|
||||
|
@ -12,7 +12,7 @@ import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import __version__ as pil_version
|
||||
|
||||
from ultralytics.yolo.utils import threaded
|
||||
from ultralytics.yolo.utils import LOGGER, threaded
|
||||
|
||||
from .checks import check_font, check_version, is_ascii
|
||||
from .files import increment_path
|
||||
@ -300,7 +300,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False):
|
||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
||||
except Exception as e:
|
||||
print(f'Warning: Plotting error for {f}: {e}')
|
||||
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
|
||||
ax[1].legend()
|
||||
fig.savefig(save_dir / 'results.png', dpi=200)
|
||||
plt.close()
|
||||
|
@ -167,11 +167,12 @@ def model_info(model, verbose=False, imgsz=640):
|
||||
n_p = get_num_params(model)
|
||||
n_g = get_num_gradients(model) # number gradients
|
||||
if verbose:
|
||||
print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
||||
LOGGER.info(
|
||||
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace('module_list.', '')
|
||||
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||
|
||||
flops = get_flops(model, imgsz)
|
||||
fused = ' (fused)' if model.is_fused() else ''
|
||||
@ -362,8 +363,8 @@ def profile(input, ops, n=10, device=None):
|
||||
results = []
|
||||
if not isinstance(device, torch.device):
|
||||
device = select_device(device)
|
||||
print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||
f"{'input':>24s}{'output':>24s}")
|
||||
LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||
f"{'input':>24s}{'output':>24s}")
|
||||
|
||||
for x in input if isinstance(input, list) else [input]:
|
||||
x = x.to(device)
|
||||
@ -393,10 +394,10 @@ def profile(input, ops, n=10, device=None):
|
||||
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
||||
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
||||
print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||
LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
LOGGER.info(e)
|
||||
results.append(None)
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
@ -22,7 +22,9 @@ class ClassificationPredictor(BasePredictor):
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
|
||||
results.append(Results(probs=pred, orig_img=orig_img, names=self.model.names))
|
||||
path, _, _, _, _ = self.batch
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
|
||||
|
||||
return results
|
||||
|
||||
|
@ -32,7 +32,9 @@ class DetectionPredictor(BasePredictor):
|
||||
orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
|
||||
shape = orig_img.shape
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
results.append(Results(boxes=pred, orig_img=orig_img, names=self.model.names))
|
||||
path, _, _, _, _ = self.batch
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
|
||||
return results
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
|
@ -24,9 +24,10 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
for i, pred in enumerate(p):
|
||||
orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
|
||||
shape = orig_img.shape
|
||||
if not len(pred):
|
||||
results.append(Results(boxes=pred[:, :6], orig_img=orig_img,
|
||||
names=self.model.names)) # save empty boxes
|
||||
path, _, _, _, _ = self.batch
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
if not len(pred): # save empty boxes
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
|
||||
continue
|
||||
if self.args.retina_masks:
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
@ -34,7 +35,8 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
else:
|
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
results.append(Results(boxes=pred[:, :6], masks=masks, orig_img=orig_img, names=self.model.names))
|
||||
results.append(
|
||||
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
||||
return results
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
|
Reference in New Issue
Block a user