ultralytics 8.0.80
single-line docstring fixes (#2060)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -73,7 +73,7 @@ ARM64 = platform.machine() in ('arm64', 'aarch64')
|
||||
|
||||
|
||||
def export_formats():
|
||||
"""YOLOv8 export formats"""
|
||||
"""YOLOv8 export formats."""
|
||||
import pandas
|
||||
x = [
|
||||
['PyTorch', '-', '.pt', True, True],
|
||||
@ -92,7 +92,7 @@ def export_formats():
|
||||
|
||||
|
||||
def gd_outputs(gd):
|
||||
"""TensorFlow GraphDef model output node names"""
|
||||
"""TensorFlow GraphDef model output node names."""
|
||||
name_list, input_list = [], []
|
||||
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
||||
name_list.append(node.name)
|
||||
@ -101,7 +101,7 @@ def gd_outputs(gd):
|
||||
|
||||
|
||||
def try_export(inner_func):
|
||||
"""YOLOv8 export decorator, i..e @try_export"""
|
||||
"""YOLOv8 export decorator, i..e @try_export."""
|
||||
inner_args = get_default_args(inner_func)
|
||||
|
||||
def outer_func(*args, **kwargs):
|
||||
@ -119,7 +119,7 @@ def try_export(inner_func):
|
||||
|
||||
|
||||
class iOSDetectModel(torch.nn.Module):
|
||||
"""Wrap an Ultralytics YOLO model for iOS export"""
|
||||
"""Wrap an Ultralytics YOLO model for iOS export."""
|
||||
|
||||
def __init__(self, model, im):
|
||||
super().__init__()
|
||||
@ -246,28 +246,28 @@ class Exporter:
|
||||
# Exports
|
||||
f = [''] * len(fmts) # exported filenames
|
||||
if jit: # TorchScript
|
||||
f[0], _ = self._export_torchscript()
|
||||
f[0], _ = self.export_torchscript()
|
||||
if engine: # TensorRT required before ONNX
|
||||
f[1], _ = self._export_engine()
|
||||
f[1], _ = self.export_engine()
|
||||
if onnx or xml: # OpenVINO requires ONNX
|
||||
f[2], _ = self._export_onnx()
|
||||
f[2], _ = self.export_onnx()
|
||||
if xml: # OpenVINO
|
||||
f[3], _ = self._export_openvino()
|
||||
f[3], _ = self.export_openvino()
|
||||
if coreml: # CoreML
|
||||
f[4], _ = self._export_coreml()
|
||||
f[4], _ = self.export_coreml()
|
||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||
self.args.int8 |= edgetpu
|
||||
f[5], s_model = self._export_saved_model()
|
||||
f[5], s_model = self.export_saved_model()
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self._export_pb(s_model)
|
||||
f[6], _ = self.export_pb(s_model)
|
||||
if tflite:
|
||||
f[7], _ = self._export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
||||
f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
|
||||
f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
|
||||
if tfjs:
|
||||
f[9], _ = self._export_tfjs()
|
||||
f[9], _ = self.export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
f[10], _ = self._export_paddle()
|
||||
f[10], _ = self.export_paddle()
|
||||
|
||||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
@ -289,8 +289,8 @@ class Exporter:
|
||||
return f # return list of exported files/dirs
|
||||
|
||||
@try_export
|
||||
def _export_torchscript(self, prefix=colorstr('TorchScript:')):
|
||||
# YOLOv8 TorchScript model export
|
||||
def export_torchscript(self, prefix=colorstr('TorchScript:')):
|
||||
"""YOLOv8 TorchScript model export."""
|
||||
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
|
||||
f = self.file.with_suffix('.torchscript')
|
||||
|
||||
@ -305,8 +305,8 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_onnx(self, prefix=colorstr('ONNX:')):
|
||||
# YOLOv8 ONNX export
|
||||
def export_onnx(self, prefix=colorstr('ONNX:')):
|
||||
"""YOLOv8 ONNX export."""
|
||||
requirements = ['onnx>=1.12.0']
|
||||
if self.args.simplify:
|
||||
requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
||||
@ -363,8 +363,8 @@ class Exporter:
|
||||
return f, model_onnx
|
||||
|
||||
@try_export
|
||||
def _export_openvino(self, prefix=colorstr('OpenVINO:')):
|
||||
# YOLOv8 OpenVINO export
|
||||
def export_openvino(self, prefix=colorstr('OpenVINO:')):
|
||||
"""YOLOv8 OpenVINO export."""
|
||||
check_requirements('openvino-dev>=2022.3') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
||||
import openvino.runtime as ov # noqa
|
||||
from openvino.tools import mo # noqa
|
||||
@ -383,8 +383,8 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_paddle(self, prefix=colorstr('PaddlePaddle:')):
|
||||
# YOLOv8 Paddle export
|
||||
def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
|
||||
"""YOLOv8 Paddle export."""
|
||||
check_requirements(('paddlepaddle', 'x2paddle'))
|
||||
import x2paddle # noqa
|
||||
from x2paddle.convert import pytorch2paddle # noqa
|
||||
@ -397,8 +397,8 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_coreml(self, prefix=colorstr('CoreML:')):
|
||||
# YOLOv8 CoreML export
|
||||
def export_coreml(self, prefix=colorstr('CoreML:')):
|
||||
"""YOLOv8 CoreML export."""
|
||||
check_requirements('coremltools>=6.0')
|
||||
import coremltools as ct # noqa
|
||||
|
||||
@ -439,8 +439,8 @@ class Exporter:
|
||||
return f, ct_model
|
||||
|
||||
@try_export
|
||||
def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
||||
# YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt
|
||||
def export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
||||
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
|
||||
assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
||||
try:
|
||||
import tensorrt as trt # noqa
|
||||
@ -451,7 +451,7 @@ class Exporter:
|
||||
|
||||
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self._export_onnx()
|
||||
f_onnx, _ = self.export_onnx()
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
||||
assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
|
||||
@ -504,9 +504,8 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
|
||||
|
||||
# YOLOv8 TensorFlow SavedModel export
|
||||
def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
|
||||
"""YOLOv8 TensorFlow SavedModel export."""
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
@ -525,7 +524,7 @@ class Exporter:
|
||||
|
||||
# Export to ONNX
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self._export_onnx()
|
||||
f_onnx, _ = self.export_onnx()
|
||||
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
@ -551,8 +550,8 @@ class Exporter:
|
||||
return str(f), keras_model
|
||||
|
||||
@try_export
|
||||
def _export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
# YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
|
||||
def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
"""YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
|
||||
import tensorflow as tf # noqa
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
||||
|
||||
@ -567,8 +566,8 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
||||
# YOLOv8 TensorFlow Lite export
|
||||
def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
||||
"""YOLOv8 TensorFlow Lite export."""
|
||||
import tensorflow as tf # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
@ -581,44 +580,9 @@ class Exporter:
|
||||
f = saved_model / f'{self.file.stem}_float32.tflite'
|
||||
return str(f), None
|
||||
|
||||
# # OLD TFLITE EXPORT CODE BELOW -------------------------------------------------------------------------------
|
||||
# batch_size, ch, *imgsz = list(self.im.shape) # BCHW
|
||||
# f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
|
||||
#
|
||||
# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
|
||||
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
|
||||
# converter.target_spec.supported_types = [tf.float16]
|
||||
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# if self.args.int8:
|
||||
#
|
||||
# def representative_dataset_gen(dataset, n_images=100):
|
||||
# # Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
|
||||
# for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
|
||||
# im = np.transpose(img, [1, 2, 0])
|
||||
# im = np.expand_dims(im, axis=0).astype(np.float32)
|
||||
# im /= 255
|
||||
# yield [im]
|
||||
# if n >= n_images:
|
||||
# break
|
||||
#
|
||||
# dataset = LoadImages(check_det_dataset(self.args.data)['train'], imgsz=imgsz, auto=False)
|
||||
# converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
|
||||
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
# converter.target_spec.supported_types = []
|
||||
# converter.inference_input_type = tf.uint8 # or tf.int8
|
||||
# converter.inference_output_type = tf.uint8 # or tf.int8
|
||||
# converter.experimental_new_quantizer = True
|
||||
# f = str(self.file).replace(self.file.suffix, '-int8.tflite')
|
||||
# if nms or agnostic_nms:
|
||||
# converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
||||
#
|
||||
# tflite_model = converter.convert()
|
||||
# open(f, 'wb').write(tflite_model)
|
||||
# return f, None
|
||||
|
||||
@try_export
|
||||
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
||||
# YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
|
||||
def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
||||
"""YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
|
||||
LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
|
||||
|
||||
cmd = 'edgetpu_compiler --version'
|
||||
@ -644,8 +608,8 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
|
||||
# YOLOv8 TensorFlow.js export
|
||||
def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
|
||||
"""YOLOv8 TensorFlow.js export."""
|
||||
check_requirements('tensorflowjs')
|
||||
import tensorflow as tf
|
||||
import tensorflowjs as tfjs # noqa
|
||||
@ -681,7 +645,7 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
def _add_tflite_metadata(self, file):
|
||||
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
|
||||
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
||||
from tflite_support import flatbuffers # noqa
|
||||
from tflite_support import metadata as _metadata # noqa
|
||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
||||
|
@ -35,6 +35,7 @@ class YOLO:
|
||||
|
||||
Args:
|
||||
model (str, Path): Path to the model file to load or create.
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
predictor (Any): The predictor object.
|
||||
@ -76,7 +77,6 @@ class YOLO:
|
||||
Args:
|
||||
model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
"""
|
||||
self.callbacks = callbacks.get_default_callbacks()
|
||||
self.predictor = None # reuse predictor
|
||||
@ -273,7 +273,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
Validate a model on a given dataset .
|
||||
Validate a model on a given dataset.
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
@ -365,7 +365,7 @@ class YOLO:
|
||||
self.model = self.trainer.model
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# update model and cfg after training
|
||||
# Update model and cfg after training
|
||||
if RANK in (-1, 0):
|
||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||
self.overrides = self.model.args
|
||||
@ -467,7 +467,7 @@ class YOLO:
|
||||
@property
|
||||
def names(self):
|
||||
"""
|
||||
Returns class names of the loaded model.
|
||||
Returns class names of the loaded model.
|
||||
"""
|
||||
return self.model.names if hasattr(self.model, 'names') else None
|
||||
|
||||
@ -481,7 +481,7 @@ class YOLO:
|
||||
@property
|
||||
def transforms(self):
|
||||
"""
|
||||
Returns transform of the loaded model.
|
||||
Returns transform of the loaded model.
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
|
@ -134,7 +134,7 @@ class BasePredictor:
|
||||
if not self.args.retina_masks:
|
||||
plot_args['im_gpu'] = im[idx]
|
||||
self.plotted_img = result.plot(**plot_args)
|
||||
# write
|
||||
# Write
|
||||
if self.args.save_txt:
|
||||
result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
|
||||
if self.args.save_crop:
|
||||
@ -153,7 +153,7 @@ class BasePredictor:
|
||||
return list(self.stream_inference(source, model)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self, source=None, model=None):
|
||||
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
||||
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
|
||||
gen = self.stream_inference(source, model)
|
||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||
pass
|
||||
@ -182,16 +182,16 @@ class BasePredictor:
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
|
||||
# setup model
|
||||
# Setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
# setup source every time predict is called
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# check if save_dir/ label file exists
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
# warmup model
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
@ -204,22 +204,22 @@ class BasePredictor:
|
||||
path, im, im0s, vid_cap, s = batch
|
||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||
|
||||
# preprocess
|
||||
# Preprocess
|
||||
with self.dt[0]:
|
||||
im = self.preprocess(im)
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
|
||||
# inference
|
||||
# Inference
|
||||
with self.dt[1]:
|
||||
preds = self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
|
||||
# postprocess
|
||||
# Postprocess
|
||||
with self.dt[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks('on_predict_postprocess_end')
|
||||
|
||||
# visualize, save, write results
|
||||
# Visualize, save, write results
|
||||
n = len(im)
|
||||
for i in range(n):
|
||||
self.results[i].speed = {
|
||||
@ -288,7 +288,7 @@ class BasePredictor:
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
im0 = self.plotted_img
|
||||
# save imgs
|
||||
# Save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
cv2.imwrite(save_path, im0)
|
||||
else: # 'video' or 'stream'
|
||||
|
@ -262,12 +262,12 @@ class Results(SimpleClass):
|
||||
kpts = self.keypoints
|
||||
texts = []
|
||||
if probs is not None:
|
||||
# classify
|
||||
# Classify
|
||||
n5 = min(len(self.names), 5)
|
||||
top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
[texts.append(f'{probs[j]:.2f} {self.names[j]}') for j in top5i]
|
||||
elif boxes:
|
||||
# detect/segment/pose
|
||||
# Detect/segment/pose
|
||||
for j, d in enumerate(boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
||||
line = (c, *d.xywhn.view(-1))
|
||||
@ -418,7 +418,7 @@ class Masks(BaseTensor):
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def segments(self):
|
||||
# Segments-deprecated (normalized)
|
||||
"""Segments-deprecated (normalized)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
|
||||
"'Masks.xy' for segments (pixels) instead.")
|
||||
return self.xyn
|
||||
@ -426,7 +426,7 @@ class Masks(BaseTensor):
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self):
|
||||
# Segments (normalized)
|
||||
"""Segments (normalized)."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
@ -434,7 +434,7 @@ class Masks(BaseTensor):
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self):
|
||||
# Segments (pixels)
|
||||
"""Segments (pixels)."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
|
@ -163,7 +163,7 @@ class BaseTrainer:
|
||||
callback(self)
|
||||
|
||||
def train(self):
|
||||
# Allow device='', device=None on Multi-GPU systems to default to device=0
|
||||
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
||||
if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3]
|
||||
world_size = torch.cuda.device_count()
|
||||
elif torch.cuda.is_available(): # i.e. device=None or device=''
|
||||
@ -306,7 +306,7 @@ class BaseTrainer:
|
||||
xi = [0, nw] # x interp
|
||||
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x['lr'] = np.interp(
|
||||
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
|
||||
if 'momentum' in x:
|
||||
@ -631,7 +631,7 @@ def check_amp(model):
|
||||
return False # AMP only used on CUDA devices
|
||||
|
||||
def amp_allclose(m, im):
|
||||
# All close FP32 vs AMP results
|
||||
"""All close FP32 vs AMP results."""
|
||||
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
|
||||
with torch.cuda.amp.autocast(True):
|
||||
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
|
||||
|
@ -149,20 +149,20 @@ class BaseValidator:
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.run_callbacks('on_val_batch_start')
|
||||
self.batch_i = batch_i
|
||||
# preprocess
|
||||
# Preprocess
|
||||
with dt[0]:
|
||||
batch = self.preprocess(batch)
|
||||
|
||||
# inference
|
||||
# Inference
|
||||
with dt[1]:
|
||||
preds = model(batch['img'])
|
||||
|
||||
# loss
|
||||
# Loss
|
||||
with dt[2]:
|
||||
if self.training:
|
||||
self.loss += trainer.criterion(preds, batch)[1]
|
||||
|
||||
# postprocess
|
||||
# Postprocess
|
||||
with dt[3]:
|
||||
preds = self.postprocess(preds)
|
||||
|
||||
|
Reference in New Issue
Block a user