make FSDP compatible with PyTorch 2.4
parent e1277af2ba
commit 5d31bf88fc
@@ -63,10 +63,9 @@ def is_sharded_fsdp(x):
 
 
 def free_if_fsdp(x):
-    if is_sharded_fsdp(x):
-        handles = x._handles
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
+    if is_sharded_fsdp(x) and x._has_params:
+        handle = x._handle
+        _reshard(x, handle, True)
 
 
 def get_fsdp_modules(x):
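Context for the hunk above: newer PyTorch (including 2.4) no longer exposes the list-valued _handles attribute on FSDP wrappers; each FSDP unit carries at most one _handle (together with _has_params), and the private _reshard helper takes that single handle and a single boolean. The following is a minimal, version-tolerant sketch of free_if_fsdp, not part of the commit; the import path and the inlined is_sharded_fsdp body are assumptions about how this module is written, and _handle, _handles, _has_params, and _reshard are all private FSDP internals that may change again.

# Sketch only, not part of the commit: a free_if_fsdp that tolerates both the
# pre-2.4 list-based handle API and the newer single-handle API.
# Every name prefixed with "_" below is a private FSDP internal.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp._runtime_utils import _reshard  # private API


def is_sharded_fsdp(x):
    # Assumed to match the helper already defined in this file.
    return isinstance(x, FSDP) and x.sharding_strategy is not ShardingStrategy.NO_SHARD


def free_if_fsdp(x):
    if not is_sharded_fsdp(x):
        return
    if hasattr(x, "_handle"):
        # Newer path (PyTorch 2.4): at most one FlatParamHandle per FSDP unit.
        if x._has_params:
            _reshard(x, x._handle, True)
    else:
        # Older path: a list of handles and a matching list of booleans.
        handles = x._handles
        _reshard(x, handles, [True for _ in handles])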
@@ -348,9 +348,12 @@ class SSLMetaArch(nn.Module):
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream",
+                         "_default_stream"}:
+                stream = getattr(self.teacher.backbone, attr)
+                setattr(self.student.dino_head, attr, stream)
+                setattr(self.teacher.dino_head, attr, stream)
+                setattr(self.student.backbone, attr, stream)
             self.need_to_synchronize_fsdp_streams = False
 
     def update_teacher(self, m):
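Context for the hunk above: as of PyTorch 2.4 the FSDP wrapper no longer exposes a single _streams mapping, so the old chained assignment is replaced by copying the individual private stream attributes one by one from the teacher backbone to the other FSDP-wrapped modules. Below is a small sketch that factors that loop into a reusable helper; it is not part of the commit. The attribute names are taken from the diff, while the share_fsdp_streams name and the commented usage are hypothetical.

# Sketch only, not part of the commit: share one module's FSDP CUDA streams
# with several other FSDP-wrapped modules. The attribute names are private
# FSDP internals as of PyTorch 2.4 and may change between releases.
import torch

_FSDP_STREAM_ATTRS = (
    "_unshard_stream",
    "_post_backward_stream",
    "_pre_unshard_stream",
    "_all_reduce_stream",
    "_default_stream",
)


def share_fsdp_streams(source, *targets):
    # Wait for all queued CUDA work before re-pointing the streams.
    torch.cuda.synchronize()
    for attr in _FSDP_STREAM_ATTRS:
        stream = getattr(source, attr)
        for target in targets:
            setattr(target, attr, stream)


# Hypothetical usage inside SSLMetaArch.fsdp_synchronize_streams:
#     if self.need_to_synchronize_fsdp_streams:
#         share_fsdp_streams(
#             self.teacher.backbone,
#             self.student.backbone,
#             self.student.dino_head,
#             self.teacher.dino_head,
#         )
#         self.need_to_synchronize_fsdp_streams = False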