pull/453/merge
Setepenre 2024-08-07 09:44:54 -04:00 committed by GitHub
commit fb262e9d94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 16 additions and 8 deletions

View File

@ -226,7 +226,7 @@ class ImageNet(ExtendedVisionDataset):
old_percent = percent
actual_index = index + 1
class_index = np.uint32(-1)
class_index = np.uint32(100001)
class_id, class_name = "", ""
entries_array[index] = (actual_index, class_index, class_id, class_name)
else:

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, List, Optional, TypeVar
import torch
from torch.utils.data import Sampler
from torchvision.datasets import ImageFolder
from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
@ -58,6 +59,8 @@ def _parse_dataset_str(dataset_str: str):
kwargs["split"] = ImageNet.Split[kwargs["split"]]
elif name == "ImageNet22k":
class_ = ImageNet22k
elif name == "ImageFolder":
class_ = ImageFolder
else:
raise ValueError(f'Unsupported dataset "{name}"')

View File

@ -157,16 +157,16 @@ class _TorchDistributedEnvironment:
self.local_rank = -1
self.local_world_size = -1
if _is_slurm_job_process():
return self._set_from_slurm_env()
env_vars = _collect_env_vars()
if not env_vars:
# Environment is not set
pass
elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
if len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
# Environment is fully set
return self._set_from_preset_env()
elif _is_slurm_job_process():
return self._set_from_slurm_env()
elif not env_vars:
# Environment is not set
pass
else:
# Environment is partially set
collected_env_vars = ", ".join(env_vars.keys())

View File

@ -74,6 +74,8 @@ def get_fsdp_modules(x):
def reshard_fsdp_model(x):
return
for m in get_fsdp_modules(x):
free_if_fsdp(m)

View File

@ -348,10 +348,12 @@ class SSLMetaArch(nn.Module):
def fsdp_synchronize_streams(self):
if self.need_to_synchronize_fsdp_streams:
torch.cuda.synchronize()
self.need_to_synchronize_fsdp_streams = False
return
self.student.dino_head._streams = (
self.teacher.dino_head._streams
) = self.student.backbone._streams = self.teacher.backbone._streams
self.need_to_synchronize_fsdp_streams = False
def update_teacher(self, m):
student_param_list = []

View File

@ -314,5 +314,6 @@ def main(args):
if __name__ == "__main__":
import sys
args = get_args_parser(add_help=True).parse_args()
main(args)