make FSDP compatible with PyTorch 2.4

pull/461/head
Adrian Wolny 2024-08-20 18:06:13 +02:00
parent e1277af2ba
commit 5d31bf88fc
2 changed files with 9 additions and 7 deletions


@@ -63,10 +63,9 @@ def is_sharded_fsdp(x):
 
 
 def free_if_fsdp(x):
-    if is_sharded_fsdp(x):
-        handles = x._handles
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
+    if is_sharded_fsdp(x) and x._has_params:
+        handle = x._handle
+        _reshard(x, handle, True)
 
 
 def get_fsdp_modules(x):

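The first hunk tracks a change in FSDP's private internals: from PyTorch 2.1 onwards a wrapped module holds a single _handle (one FlatParamHandle) instead of the older _handles list, and _reshard takes one handle and one boolean rather than two parallel lists; the added x._has_params guard skips wrappers that manage no flat parameter. A minimal sketch of a version-tolerant variant is below; free_if_fsdp_compat is a hypothetical name, and the private attributes it probes are assumptions based on this diff rather than a stable API.

# Sketch only (hypothetical helper, not part of this PR): a version-tolerant
# variant of free_if_fsdp. It relies on private FSDP internals (_handle,
# _handles, _reshard), which are not a public PyTorch API and may change again.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp._runtime_utils import _reshard


def free_if_fsdp_compat(x):
    # Same intent as free_if_fsdp: make a sharded FSDP module free its
    # unsharded flat parameter(s) again.
    if not (isinstance(x, FSDP) and x.sharding_strategy is not ShardingStrategy.NO_SHARD):
        return
    if getattr(x, "_handle", None) is not None:
        # PyTorch >= 2.1 (incl. 2.4): a single FlatParamHandle per FSDP module,
        # mirroring the patched code above
        _reshard(x, x._handle, True)
    elif hasattr(x, "_handles"):
        # Older FSDP internals: a list of handles plus a parallel list of flags,
        # mirroring the removed code above
        handles = x._handles
        _reshard(x, handles, [True for _ in handles])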

@@ -348,9 +348,12 @@ class SSLMetaArch(nn.Module):
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream",
+                         "_default_stream"}:
+                stream = getattr(self.teacher.backbone, attr)
+                setattr(self.student.dino_head, attr, stream)
+                setattr(self.teacher.dino_head, attr, stream)
+                setattr(self.student.backbone, attr, stream)
             self.need_to_synchronize_fsdp_streams = False
 
     def update_teacher(self, m):
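The second hunk adapts fsdp_synchronize_streams to the same internals change: newer FSDP roots no longer expose a single _streams dict, so each internal CUDA stream attribute is copied individually from the teacher backbone to the other three FSDP roots, keeping all four on shared streams. A reusable helper with the same effect could look like the sketch below; share_fsdp_streams is a hypothetical name, and the attribute list is taken verbatim from the patch and remains a private, version-dependent detail.

# Sketch only (hypothetical helper, not part of this PR): copy the private
# FSDP CUDA stream attributes of `source` onto every module in `targets`.
_FSDP_STREAM_ATTRS = (
    "_unshard_stream",
    "_post_backward_stream",
    "_pre_unshard_stream",
    "_all_reduce_stream",
    "_default_stream",
)


def share_fsdp_streams(source, *targets):
    for attr in _FSDP_STREAM_ATTRS:
        stream = getattr(source, attr)
        for target in targets:
            setattr(target, attr, stream)

With such a helper, the body of the patched method would reduce to torch.cuda.synchronize() followed by share_fsdp_streams(self.teacher.backbone, self.student.backbone, self.student.dino_head, self.teacher.dino_head).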