_handle and _streams fix

parent 0e1c5b4e35
commit 087a4cea01

@@ -0,0 +1,9 @@
+train:
+  dataset_path: /home/manon/classification/data/Single_cells/medhi
+  output_dir: /home/guevel/OT4D/cell_similarity/logs
+  batch_size_per_gpu: 32
+  OFFICIAL_EPOCH_LENGTH: 200
+student:
+  arch: vit_large
+  block_chunks: 4
+
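
For reference, a minimal sketch of how a config fragment like this is typically consumed, assuming an OmegaConf-based setup (DINOv2's config handling is built on OmegaConf). The field values repeat the hunk above; the file name and loading code are illustrative only.

    # Hedged sketch: load the YAML added above and read the fields the
    # training code relies on. "train_config.yaml" is a hypothetical name.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("train_config.yaml")
    print(cfg.train.dataset_path)        # /home/manon/classification/data/Single_cells/medhi
    print(cfg.train.batch_size_per_gpu)  # 32
    print(cfg.student.arch)              # vit_large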

@@ -4,7 +4,7 @@
 # found in the LICENSE file in the root directory of this source tree.

 from .adapters import DatasetWithEnumeratedTargets
-from .loaders import make_data_loader, make_dataset, SamplerType
+from .loaders import make_data_loader, make_dataset, make_custom_dataset, SamplerType
 from .collate import collate_data_and_cast
 from .masking import MaskingGenerator
 from .augmentations import DataAugmentationDINO

@@ -64,9 +64,8 @@ def is_sharded_fsdp(x):

 def free_if_fsdp(x):
     if is_sharded_fsdp(x):
-        handles = x._handle
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
+        handle = x._handle
+        _reshard(x, handle, True)


 def get_fsdp_modules(x):
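
The hunk above tracks an FSDP-internal change: newer PyTorch releases (around 2.1) keep a single FlatParamHandle per wrapped module, so the old list of handles plus a list of booleans becomes a single `_handle` plus one boolean passed to `_reshard`. Below is a hedged sketch of a variant that tolerates both layouts; it mirrors the two call forms in the diff and is not code from the repository.

    # Hedged sketch: reshard whether the FSDP module exposes a single handle
    # (newer PyTorch, the "+" lines above) or a list of handles (older
    # PyTorch, the "-" lines). Purely illustrative.
    def free_if_fsdp_compat(x):
        if is_sharded_fsdp(x):
            handle = getattr(x, "_handle", None)
            if handle is not None:
                _reshard(x, handle, True)
            else:
                handles = x._handles
                _reshard(x, handles, [True] * len(handles))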

@@ -24,7 +24,7 @@ try:
 except ImportError:
     raise AssertionError("xFormers is required for training")


+TORCH_VERSION = torch.__version__
 logger = logging.getLogger("dinov2")


@@ -348,9 +348,19 @@ class SSLMetaArch(nn.Module):
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            if int(TORCH_VERSION[2]) < 1:
+                self.student.dino_head._streams = (
+                    self.teacher.dino_head._streams
+                ) = self.student.backbone._streams = self.teacher.backbone._streams
+
+            else:
+                attrs = ["_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"]
+                for attr in attrs:
+                    stream = getattr(self.teacher.dino_head, attr)
+                    setattr(self.student.dino_head, attr, stream)
+                    setattr(self.student.backbone, attr, stream)
+                    setattr(self.teacher.backbone, attr, stream)
+
             self.need_to_synchronize_fsdp_streams = False

     def update_teacher(self, m):
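
The new guard keys off a single character of the version string: on PyTorch 2.0.x `torch.__version__` starts with "2.0", so `TORCH_VERSION[2]` is "0" and the legacy `_streams` assignment runs, while on 2.1 and later that character is "1" or higher and the per-attribute getattr/setattr copy runs instead. As a hedged aside, the same condition written against parsed version numbers, which also survives suffixed builds such as "2.1.0+cu118"; this is only an illustration, not what the commit does.

    # Hedged sketch: version check based on parsed major/minor numbers rather
    # than a single character of the version string. Illustrative only.
    import torch

    def fsdp_streams_attr_is_legacy() -> bool:
        major, minor = (int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
        return (major, minor) < (2, 1)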

@@ -191,7 +191,7 @@ def do_train(cfg, model, resume=False):
     # setup data loader

     dataset = make_custom_dataset(
-        dataset_str=cfg.train.dataset_path,
+        dataset_path=cfg.train.dataset_path,
         transform=data_transform,
     )
     # sampler_type = SamplerType.INFINITE
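
The keyword rename above aligns the call with the argument name `make_custom_dataset` expects, with the path taken from `cfg.train.dataset_path` in the config hunk at the top of this diff. Below is a rough, hypothetical sketch of the shape such a helper could take, assuming it wraps a torchvision ImageFolder; the actual implementation lives elsewhere in the repository and may differ.

    # Hedged sketch only: a make_custom_dataset matching the keyword arguments
    # used in the call above. ImageFolder is an assumption about the backing
    # dataset type, not taken from this commit.
    from torchvision.datasets import ImageFolder

    def make_custom_dataset(dataset_path, transform=None, target_transform=None):
        return ImageFolder(dataset_path, transform=transform, target_transform=target_transform)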

@@ -297,6 +297,7 @@ def do_train(cfg, model, resume=False):

 def main(args):
     cfg = setup(args)
+    torch.cuda.memory._record_memory_history()

     model = SSLMetaArch(cfg).to(torch.device("cuda"))
     model.prepare_for_distributed_training()

@@ -312,6 +313,7 @@ def main(args):
         return do_test(cfg, model, f"manual_{iteration}")

     do_train(cfg, model, resume=not args.no_resume)
+    torch.cuda.memory._dump_snapshot(os.path.join(cfg.train.output_dir, "memory_snapshot.pickle"))


 if __name__ == "__main__":
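
The two `torch.cuda.memory` calls added in the last hunks bracket training with CUDA allocator tracing: `_record_memory_history()` starts recording allocation events and `_dump_snapshot(...)` writes them to a pickle under `cfg.train.output_dir`. A minimal sketch of how such a snapshot can be inspected offline; the file name repeats the basename from the diff, and the field names are an assumption about the current PyTorch snapshot layout.

    # Hedged sketch: open the dumped snapshot and summarize it. The same pickle
    # can also be loaded in PyTorch's web viewer at https://pytorch.org/memory_viz.
    import pickle

    with open("memory_snapshot.pickle", "rb") as f:
        snapshot = pickle.load(f)

    print(snapshot.keys())  # e.g. dict_keys(['segments', 'device_traces'])
    total = sum(seg["total_size"] for seg in snapshot["segments"])
    print(f"{total / 2**20:.1f} MiB held in cached segments at dump time")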