Add world_size check before setting up DDP train (#3191)

Branch: single_channel
Author: Bruno Arine, committed 1 year ago via GitHub
parent f8e1dcc43f
commit 0d91d6df6e

@@ -217,7 +217,7 @@ class BaseTrainer:
         callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
         self.amp = torch.tensor(check_amp(self.model), device=self.device)
         callbacks.default_callbacks = callbacks_backup  # restore callbacks
-        if RANK > -1:  # DDP
+        if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
         self.scaler = amp.GradScaler(enabled=self.amp)
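For context, here is a minimal standalone sketch of the guarded pattern this commit introduces; sync_amp_flag is a hypothetical helper, not the actual BaseTrainer code. The point of the extra check is that dist.broadcast() is only valid once a distributed process group has been initialized, and a single-process run (world_size == 1) may still carry a non-negative rank depending on how it was launched, so rank alone is not a safe signal that DDP is active.

import torch
import torch.distributed as dist

def sync_amp_flag(amp_ok: bool, rank: int, world_size: int, device: torch.device) -> bool:
    """Share rank 0's AMP decision with every rank, but only when DDP is really active."""
    amp = torch.tensor(amp_ok, device=device)  # wrap the flag so it can be broadcast
    # Guard mirrors the commit: rank > -1 alone is not enough, because a
    # single-process launch can still set a rank while world_size == 1,
    # and dist.broadcast() requires an initialized process group.
    if rank > -1 and world_size > 1:  # multi-process DDP only
        dist.broadcast(amp, src=0)  # rank 0's value overwrites all other ranks
    return bool(amp)  # back to a plain Python bool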
