_handle and _streams fix

pull/422/head
Etienne Guevel 2024-05-15 16:01:34 +02:00
parent 0e1c5b4e35
commit 087a4cea01
5 changed files with 29 additions and 9 deletions

View File

@@ -0,0 +1,9 @@
+train:
+  dataset_path: /home/manon/classification/data/Single_cells/medhi
+  output_dir: /home/guevel/OT4D/cell_similarity/logs
+  batch_size_per_gpu: 32
+  OFFICIAL_EPOCH_LENGTH: 200
+student:
+  arch: vit_large
+  block_chunks: 4
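For context, this new file only overrides a handful of fields; dinov2 merges override files like this onto its default SSL config with OmegaConf. A minimal sketch of that merge, where the override file path is an assumption and not part of the commit:

# Hedged sketch: merge the override file above onto dinov2's default SSL config.
# The override path is hypothetical; ssl_default_config.yaml ships with dinov2.
from omegaconf import OmegaConf

default_cfg = OmegaConf.load("dinov2/configs/ssl_default_config.yaml")
override_cfg = OmegaConf.load("path/to/this_override.yaml")
cfg = OmegaConf.merge(default_cfg, override_cfg)

print(cfg.train.batch_size_per_gpu)  # 32
print(cfg.student.arch)              # vit_large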

View File

@@ -4,7 +4,7 @@
 # found in the LICENSE file in the root directory of this source tree.
 from .adapters import DatasetWithEnumeratedTargets
-from .loaders import make_data_loader, make_dataset, SamplerType
+from .loaders import make_data_loader, make_dataset, make_custom_dataset, SamplerType
 from .collate import collate_data_and_cast
 from .masking import MaskingGenerator
 from .augmentations import DataAugmentationDINO
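The re-exported make_custom_dataset lives in the fork's loaders module, whose implementation is not part of this diff; the only thing visible here is the call in train.py below, which passes dataset_path and transform. A hypothetical sketch of such a helper, assuming an ImageFolder-style directory layout:

# Hypothetical sketch of make_custom_dataset; the fork's real loaders.py is not shown in this commit.
# Only the (dataset_path, transform) keywords are taken from the call site in train.py.
from typing import Callable, Optional

from torchvision import datasets


def make_custom_dataset(
    dataset_path: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    # Treat dataset_path as a root folder of class sub-directories containing images.
    return datasets.ImageFolder(dataset_path, transform=transform, target_transform=target_transform)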

View File

@@ -64,9 +64,8 @@ def is_sharded_fsdp(x):
 def free_if_fsdp(x):
     if is_sharded_fsdp(x):
-        handles = x._handle
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
+        handle = x._handle
+        _reshard(x, handle, True)
 def get_fsdp_modules(x):
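This is the _handle half of the fix: older torch FSDP kept a list of flat-parameter handles and _reshard took parallel lists, while from torch 2.1 each FSDP module exposes a single handle and _reshard takes one handle and one bool. A version-tolerant sketch of the same helper (these are private torch.distributed.fsdp internals and may change again; is_sharded_fsdp is assumed to come from the same helper module as the diff above):

# Hedged sketch: support both the pre-2.1 list API and the 2.1+ single-handle API.
# _reshard and the _handle/_handles attributes are private FSDP internals.
from torch.distributed.fsdp._runtime_utils import _reshard


def free_if_fsdp(x):
    if is_sharded_fsdp(x):
        handle = getattr(x, "_handle", None)
        if handle is not None:  # torch >= 2.1: a single FlatParamHandle
            _reshard(x, handle, True)
        else:  # torch < 2.1: a list of handles and a parallel list of bools
            handles = x._handles
            _reshard(x, handles, [True for _ in handles])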

View File

@@ -24,7 +24,7 @@ try:
 except ImportError:
     raise AssertionError("xFormers is required for training")
+TORCH_VERSION = torch.__version__
 logger = logging.getLogger("dinov2")
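TORCH_VERSION is the raw torch.__version__ string (for example "2.1.0+cu118"). The check in the next hunk indexes character 2, which happens to be the minor-version digit for 2.0.x and 2.1.x but breaks on strings like "1.13.1" or a future "2.10". A more robust comparison as a sketch; the names here are my own, not from the commit:

# Hedged sketch: compare (major, minor) as integers instead of a single character.
import torch

_major, _minor = (int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
TORCH_PRE_2_1 = (_major, _minor) < (2, 1)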
@@ -348,9 +348,19 @@ class SSLMetaArch(nn.Module):
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            if int(TORCH_VERSION[2]) < 1:
+                self.student.dino_head._streams = (
+                    self.teacher.dino_head._streams
+                ) = self.student.backbone._streams = self.teacher.backbone._streams
+            else:
+                attrs = ["_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"]
+                for attr in attrs:
+                    stream = getattr(self.teacher.dino_head, attr)
+                    setattr(self.student.dino_head, attr, stream)
+                    setattr(self.student.backbone, attr, stream)
+                    setattr(self.teacher.backbone, attr, stream)
             self.need_to_synchronize_fsdp_streams = False
     def update_teacher(self, m):
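This is the _streams half of the fix: up to torch 2.0 an FSDP root kept its CUDA streams in a single _streams attribute, so the four wrapped modules could simply share it, whereas from torch 2.1 the streams live in separate per-purpose attributes (the names in the attrs list above), so each one is copied from the teacher's dino_head onto the other three modules. A hypothetical helper that expresses the same idea for any set of FSDP roots; share_fsdp_streams is my own name and the attributes remain private torch internals:

# Hedged sketch: copy FSDP CUDA-stream state from one FSDP root onto others.
# Attribute names are taken from the diff above; they are private torch.distributed.fsdp internals.
_FSDP_STREAM_ATTRS = (
    "_unshard_stream",
    "_post_backward_stream",
    "_pre_unshard_stream",
    "_all_reduce_stream",
    "_default_stream",
)


def share_fsdp_streams(src, *dsts):
    if hasattr(src, "_streams"):  # torch < 2.1: one attribute holds every stream
        for dst in dsts:
            dst._streams = src._streams
    else:  # torch >= 2.1: one attribute per stream
        for attr in _FSDP_STREAM_ATTRS:
            stream = getattr(src, attr)
            for dst in dsts:
                setattr(dst, attr, stream)

With such a helper, fsdp_synchronize_streams would reduce to a single call: share_fsdp_streams(self.teacher.dino_head, self.student.dino_head, self.student.backbone, self.teacher.backbone).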

View File

@@ -191,7 +191,7 @@ def do_train(cfg, model, resume=False):
     # setup data loader
     dataset = make_custom_dataset(
-        dataset_str=cfg.train.dataset_path,
+        dataset_path=cfg.train.dataset_path,
         transform=data_transform,
     )
     # sampler_type = SamplerType.INFINITE
@@ -297,6 +297,7 @@
 def main(args):
     cfg = setup(args)
+    torch.cuda.memory._record_memory_history()
     model = SSLMetaArch(cfg).to(torch.device("cuda"))
     model.prepare_for_distributed_training()
@@ -312,6 +313,7 @@
         return do_test(cfg, model, f"manual_{iteration}")
     do_train(cfg, model, resume=not args.no_resume)
+    torch.cuda.memory._dump_snapshot(os.path.join(cfg.train.output_dir, "memory_snapshot.pickle"))
 if __name__ == "__main__":
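The two torch.cuda.memory calls bracket the training run with CUDA allocator profiling: _record_memory_history() starts recording allocations at startup and _dump_snapshot() writes them to a pickle in the output directory at the end. Both are private torch APIs (available from roughly torch 2.1). A minimal standalone sketch of the same workflow, with the workload and file name chosen only for illustration:

# Hedged sketch: record allocator history around a CUDA workload and dump a snapshot.
# _record_memory_history/_dump_snapshot are private torch.cuda.memory APIs (torch >= ~2.1).
import torch

torch.cuda.memory._record_memory_history()

x = torch.randn(4096, 4096, device="cuda")
y = x @ x  # any CUDA work from here on shows up in the history

torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
# The pickle can be inspected at https://pytorch.org/memory_viz or, for example, with:
#   python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o trace.html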