ultralytics 8.0.57
Comet, AMP, Classify, Docker updates (#1601)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -142,6 +142,7 @@ class YOLO:
|
||||
self.task = task or guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.overrides['model'] = weights
|
||||
self.overrides['task'] = self.task
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
"""
|
||||
|
@ -203,8 +203,8 @@ class BaseTrainer:
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
# Check AMP
|
||||
self.amp = torch.tensor(True).to(self.device)
|
||||
if RANK in (-1, 0): # Single-GPU and DDP
|
||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
|
Reference in New Issue
Block a user