Add scaling factor for gradient accumulation in forward_backward method
parent
235eac76c9
commit
9bdb158362
|
@ -129,7 +129,7 @@ class SSLMetaArch(nn.Module):
|
|||
else:
|
||||
loss.backward()
|
||||
|
||||
def forward_backward(self, images, teacher_temp):
|
||||
def forward_backward(self, images, teacher_temp, scale=1.0):
|
||||
n_global_crops = 2
|
||||
assert n_global_crops == 2
|
||||
n_local_crops = self.cfg.crops.local_crops_number
|
||||
|
@ -339,6 +339,9 @@ class SSLMetaArch(nn.Module):
|
|||
# accumulate loss
|
||||
loss_accumulator += self.ibot_loss_weight * ibot_patch_loss
|
||||
|
||||
# Apply scaling factor for gradient accumulation
|
||||
loss_accumulator = loss_accumulator * scale
|
||||
|
||||
self.backprop_loss(loss_accumulator)
|
||||
|
||||
self.fsdp_synchronize_streams()
|
||||
|
|
Loading…
Reference in New Issue