`ultralytics 8.0.76` minor fixes and improvements (#2004)

Co-authored-by: Seungtaek Kim <seungtaek.kim.94@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Ercalvez <45692523+Ercalvez@users.noreply.github.com>
Co-authored-by: Erwan CALVEZ <ecalvez@enib.fr>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 48c4483795
commit 4916014af2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -76,11 +76,11 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
if [ "${{ matrix.os }}" == "macos-latest" ]; then
pip install -e . coremltools openvino-dev tensorflow-macos tensorflowjs --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[export-macos]' --extra-index-url https://download.pytorch.org/whl/cpu
else
pip install -e . coremltools openvino-dev tensorflow-cpu tensorflowjs --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[export-cpu]' --extra-index-url https://download.pytorch.org/whl/cpu
fi
yolo export format=tflite
yolo export format=tflite imgsz=32
- name: Check environment
run: |
echo "RUNNER_OS is ${{ runner.os }}"

@ -30,7 +30,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache '.[export]' albumentations comet gsutil notebook
RUN pip install --no-cache . albumentations comet gsutil notebook
# Set environment variables
ENV OMP_NUM_THREADS=1

@ -26,7 +26,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache '.[export]' albumentations gsutil notebook \
RUN pip install --no-cache . albumentations gsutil notebook \
--extra-index-url https://download.pytorch.org/whl/cpu
# Cleanup

@ -139,7 +139,7 @@ predicts the classes and locations of objects in the input images or videos.
results = model.predict(source=0, stream=True)
for result in results:
# detection
# Detection
result.boxes.xyxy # box with xyxy format, (N, 4)
result.boxes.xywh # box with xywh format, (N, 4)
result.boxes.xyxyn # box with xyxy format but normalized, (N, 4)
@ -147,12 +147,12 @@ predicts the classes and locations of objects in the input images or videos.
result.boxes.conf # confidence score, (N, 1)
result.boxes.cls # cls, (N, 1)
# segmentation
result.masks.masks # masks, (N, H, W)
# Segmentation
result.masks.data # masks, (N, H, W)
result.masks.xy # x,y segments (pixels), List[segment] * N
result.masks.xyn # x,y segments (normalized), List[segment] * N
# classification
# Classification
result.probs # cls prob, (num_class, )
# Each result is composed of torch.Tensor by default,

@ -39,8 +39,9 @@ setup(
install_requires=REQUIREMENTS + PKG_REQUIREMENTS,
extras_require={
'dev': ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs-material', 'mkdocstrings[python]'],
'export': ['coremltools>=6.0', 'onnx', 'onnxsim', 'onnxruntime', 'openvino-dev>=2022.3'],
'tf': ['onnx2tf', 'sng4onnx', 'tflite_support', 'tensorflow']},
'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow', 'tensorflowjs'],
'export-cpu': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-cpu', 'tensorflowjs'],
'export-macos': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-macos', 'tensorflowjs']},
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',

@ -217,7 +217,7 @@ def test_result():
res[0].plot(conf=True, boxes=False, masks=True)
res[0].plot(pil=True)
res[0] = res[0].cpu().numpy()
print(res[0].path, res[0].masks.masks)
print(res[0].path, res[0].masks.data)
model = YOLO('yolov8n.pt')
res = model(SOURCE)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.75'
__version__ = '8.0.76'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO

@ -8,6 +8,7 @@ import requests
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.yolo.utils.errors import HUBModelError
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
@ -55,7 +56,8 @@ class HUBTrainingSession:
elif len(url) == 20:
key, model_id = '', url
else:
raise ValueError(f'Invalid HUBTrainingSession input: {url}')
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")
# Authorize
auth = Auth(key)

@ -116,7 +116,7 @@ class YOLO:
@staticmethod
def is_hub_model(model):
return any((
model.startswith('https://hub.ultralytics.com/models/'),
model.startswith('https://hub.ultra'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID

@ -207,7 +207,7 @@ class Results(SimpleClass):
if pred_masks and show_masks:
if img_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.masks.device).permute(
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
2, 0, 1).flip(0).contiguous() / 255
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu)

@ -632,9 +632,9 @@ def check_amp(model):
def amp_allclose(m, im):
# All close FP32 vs AMP results
a = m(im, device=device, verbose=False)[0].boxes.boxes # FP32 inference
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
with torch.cuda.amp.autocast(True):
b = m(im, device=device, verbose=False)[0].boxes.boxes # AMP inference
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance

@ -0,0 +1,9 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.utils import emojis
class HUBModelError(Exception):
def __init__(self, message='Model not found. Please check model URL and try again.'):
super().__init__(emojis(message))

@ -172,19 +172,37 @@ class FocalLoss(nn.Module):
class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
def __init__(self, nc, conf=0.25, iou_thres=0.45):
self.matrix = np.zeros((nc + 1, nc + 1))
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
self.task = task
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
self.nc = nc # number of classes
self.conf = conf
self.iou_thres = iou_thres
def process_cls_preds(self, preds, targets):
"""
Update confusion matrix for classification task
Arguments:
preds (Array[N, min(nc,5)])
targets (Array[N, 1])
Returns:
None, updates confusion matrix accordingly
"""
preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
self.matrix[t][p] += 1
def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
"""
@ -231,7 +249,7 @@ class ConfusionMatrix:
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class
return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()

@ -547,9 +547,9 @@ def crop_mask(masks, boxes):
(torch.Tensor): The masks are being cropped to the bounding box.
"""
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

@ -3,7 +3,7 @@
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
from ultralytics.yolo.utils.metrics import ClassifyMetrics
from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
class ClassificationValidator(BaseValidator):
@ -12,11 +12,15 @@ class ClassificationValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'classify'
self.metrics = ClassifyMetrics()
self.save_dir = save_dir
def get_desc(self):
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
def init_metrics(self, model):
self.names = model.names
self.nc = len(model.names)
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
self.pred = []
self.targets = []
@ -32,8 +36,9 @@ class ClassificationValidator(BaseValidator):
self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs):
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
self.metrics.speed = self.speed
# self.metrics.confusion_matrix = self.confusion_matrix # TODO: classification ConfusionMatrix
self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self):
self.metrics.process(self.targets, self.pred)
@ -50,6 +55,8 @@ class ClassificationValidator(BaseValidator):
def print_results(self):
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
if self.args.plots:
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
def val(cfg=DEFAULT_CFG, use_python=False):

Loading…
Cancel
Save