Add scaling factor for gradient accumulation in forward_backward method

pull/509/head
ChuaHanChong 2025-03-25 06:06:38 +00:00
parent 235eac76c9
commit 9bdb158362
1 changed file with 4 additions and 1 deletion

View File

@ -129,7 +129,7 @@ class SSLMetaArch(nn.Module):
else:
loss.backward()
def forward_backward(self, images, teacher_temp):
def forward_backward(self, images, teacher_temp, scale=1.0):
n_global_crops = 2
assert n_global_crops == 2
n_local_crops = self.cfg.crops.local_crops_number
@ -339,6 +339,9 @@ class SSLMetaArch(nn.Module):
# accumulate loss
loss_accumulator += self.ibot_loss_weight * ibot_patch_loss
# Apply scaling factor for gradient accumulation
loss_accumulator = loss_accumulator * scale
self.backprop_loss(loss_accumulator)
self.fsdp_synchronize_streams()