Merge 1f3073a94f
into e1277af2ba
commit
fb262e9d94
|
@ -226,7 +226,7 @@ class ImageNet(ExtendedVisionDataset):
|
|||
old_percent = percent
|
||||
|
||||
actual_index = index + 1
|
||||
class_index = np.uint32(-1)
|
||||
class_index = np.uint32(100001)
|
||||
class_id, class_name = "", ""
|
||||
entries_array[index] = (actual_index, class_index, class_id, class_name)
|
||||
else:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, List, Optional, TypeVar
|
|||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
from .datasets import ImageNet, ImageNet22k
|
||||
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
|
||||
|
@ -58,6 +59,8 @@ def _parse_dataset_str(dataset_str: str):
|
|||
kwargs["split"] = ImageNet.Split[kwargs["split"]]
|
||||
elif name == "ImageNet22k":
|
||||
class_ = ImageNet22k
|
||||
elif name == "ImageFolder":
|
||||
class_ = ImageFolder
|
||||
else:
|
||||
raise ValueError(f'Unsupported dataset "{name}"')
|
||||
|
||||
|
|
|
@ -157,16 +157,16 @@ class _TorchDistributedEnvironment:
|
|||
self.local_rank = -1
|
||||
self.local_world_size = -1
|
||||
|
||||
if _is_slurm_job_process():
|
||||
return self._set_from_slurm_env()
|
||||
|
||||
env_vars = _collect_env_vars()
|
||||
if not env_vars:
|
||||
# Environment is not set
|
||||
pass
|
||||
elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
|
||||
if len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
|
||||
# Environment is fully set
|
||||
return self._set_from_preset_env()
|
||||
elif _is_slurm_job_process():
|
||||
return self._set_from_slurm_env()
|
||||
elif not env_vars:
|
||||
# Environment is not set
|
||||
pass
|
||||
else:
|
||||
# Environment is partially set
|
||||
collected_env_vars = ", ".join(env_vars.keys())
|
||||
|
|
|
@ -74,6 +74,8 @@ def get_fsdp_modules(x):
|
|||
|
||||
|
||||
def reshard_fsdp_model(x):
|
||||
return
|
||||
|
||||
for m in get_fsdp_modules(x):
|
||||
free_if_fsdp(m)
|
||||
|
||||
|
|
|
@ -348,10 +348,12 @@ class SSLMetaArch(nn.Module):
|
|||
def fsdp_synchronize_streams(self):
|
||||
if self.need_to_synchronize_fsdp_streams:
|
||||
torch.cuda.synchronize()
|
||||
self.need_to_synchronize_fsdp_streams = False
|
||||
return
|
||||
self.student.dino_head._streams = (
|
||||
self.teacher.dino_head._streams
|
||||
) = self.student.backbone._streams = self.teacher.backbone._streams
|
||||
self.need_to_synchronize_fsdp_streams = False
|
||||
|
||||
|
||||
def update_teacher(self, m):
|
||||
student_param_list = []
|
||||
|
|
|
@ -314,5 +314,6 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
args = get_args_parser(add_help=True).parse_args()
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue