ultralytics 8.0.72
faster Windows trainings and corrupt cache fix (#1912)
Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -162,7 +162,18 @@ def check_source(source):
|
||||
|
||||
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
|
||||
"""
|
||||
TODO: docs
|
||||
Loads an inference source for object detection and applies necessary transformations.
|
||||
|
||||
Args:
|
||||
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
|
||||
transforms (callable, optional): Custom transformations to be applied to the input source.
|
||||
imgsz (int, optional): The size of the image for inference. Default is 640.
|
||||
vid_stride (int, optional): The frame interval for video sources. Default is 1.
|
||||
stride (int, optional): The model stride. Default is 32.
|
||||
auto (bool, optional): Automatically apply pre-processing. Default is True.
|
||||
|
||||
Returns:
|
||||
dataset: A dataset object for the specified input source.
|
||||
"""
|
||||
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
|
||||
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
|
||||
@ -179,7 +190,6 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
|
||||
auto=auto,
|
||||
transforms=transforms,
|
||||
vid_stride=vid_stride)
|
||||
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
|
||||
elif from_img:
|
||||
@ -192,6 +202,7 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
|
||||
transforms=transforms,
|
||||
vid_stride=vid_stride)
|
||||
|
||||
setattr(dataset, 'source_type', source_type) # attach source types
|
||||
# Attach source types to the dataset
|
||||
setattr(dataset, 'source_type', source_type)
|
||||
|
||||
return dataset
|
||||
|
@ -77,7 +77,6 @@ class YOLODataset(BaseDataset):
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||
total = len(self.im_files)
|
||||
nc = len(self.data['names'])
|
||||
nkpt, ndim = self.data.get('kpt_shape', (0, 0))
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
||||
raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
|
Reference in New Issue
Block a user