Fix save_txt in track mode and add Keypoints and Probs (#2921)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Laughing
2023-06-01 06:40:10 +08:00
committed by GitHub
parent 6c65934b55
commit f4b34fc30b
2 changed files with 156 additions and 58 deletions

View File

@ -207,42 +207,34 @@ def test_predict_callback_and_setup():
print(boxes)
def test_result():
model = YOLO('yolov8n-pose.pt')
res = model([SOURCE, SOURCE])
res[0].plot(conf=True, boxes=False)
res[0].plot(pil=True)
res[0] = res[0].cpu().numpy()
print(res[0].path, res[0].keypoints)
def _test_results_api(res):
# General apis except plot
res = res.cpu().numpy()
# res = res.cuda()
res = res.to(device='cpu', dtype=torch.float32)
res.save_txt('label.txt', save_conf=False)
res.save_txt('label.txt', save_conf=True)
res.save_crop('crops/')
res.tojson(normalize=False)
res.tojson(normalize=True)
res.plot(pil=True)
res.plot(conf=True, boxes=False)
res.plot()
print(res.path)
for k in res.keys:
print(getattr(res, k).data)
model = YOLO('yolov8n-seg.pt')
res = model([SOURCE, SOURCE])
res[0].plot(conf=True, boxes=False, masks=True)
res[0].plot(pil=True)
res[0] = res[0].cpu().numpy()
print(res[0].path, res[0].masks.data)
model = YOLO('yolov8n.pt')
res = model(SOURCE)
res[0].plot(pil=True)
res[0].plot()
res[0] = res[0].cpu().numpy()
print(res[0].path)
model = YOLO('yolov8n-cls.pt')
res = model(SOURCE)
res[0].plot(probs=False)
res[0].plot(pil=True)
res[0].plot()
res[0] = res[0].cpu().numpy()
print(res[0].path)
def test_results():
for m in ['yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt', 'yolov8n-cls.pt']:
model = YOLO(m)
res = model([SOURCE, SOURCE])
_test_results_api(res[0])
def test_track():
im = cv2.imread(str(SOURCE))
model = YOLO(MODEL)
seg_model = YOLO('yolov8n-seg.pt')
pose_model = YOLO('yolov8n-pose.pt')
model.track(source=im)
seg_model.track(source=im)
pose_model.track(source=im)
for m in ['yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt']:
model = YOLO(m)
res = model.track(source=im)
_test_results_api(res[0])