pull/409/head
Tuatini Godard 2024-04-17 14:31:30 +02:00
parent e1277af2ba
commit b05584f6f8
No known key found for this signature in database
2 changed files with 9 additions and 8 deletions

View File

@ -62,11 +62,10 @@ def is_sharded_fsdp(x):
return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD
def free_if_fsdp(x):
if is_sharded_fsdp(x):
handles = x._handles
true_list = [True for h in handles]
_reshard(x, handles, true_list)
def free_if_fsdp(x: FSDP):
if is_sharded_fsdp(x) and x._has_params:
handle = x._handle
_reshard(x, handle, True)
def get_fsdp_modules(x):

View File

@ -348,9 +348,11 @@ class SSLMetaArch(nn.Module):
def fsdp_synchronize_streams(self):
if self.need_to_synchronize_fsdp_streams:
torch.cuda.synchronize()
self.student.dino_head._streams = (
self.teacher.dino_head._streams
) = self.student.backbone._streams = self.teacher.backbone._streams
for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"}:
stream = getattr(self.teacher.backbone, attr)
setattr(self.student.dino_head, attr, stream)
setattr(self.teacher.dino_head, attr, stream)
setattr(self.student.backbone, attr, stream)
self.need_to_synchronize_fsdp_streams = False
def update_teacher(self, m):