make training on custom data work!

parent e1277af2ba
commit 400983cf83
dinov2/data/datasets/__init__.py
@@ -5,3 +5,4 @@
 
 from .image_net import ImageNet
 from .image_net_22k import ImageNet22k
+from .custom_data import CustomData
dinov2/data/datasets/custom_data.py (new file)
@@ -0,0 +1,41 @@
+from enum import Enum
+import logging
+import os
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+from torchvision.datasets import DatasetFolder
+
+from .extended import ExtendedVisionDataset
+
+
+logger = logging.getLogger("dinov2")
+
+
+def pil_loader(p):
+    return Image.open(p).convert("RGB")
+
+
+class CustomData(ExtendedVisionDataset):
+
+    def __init__(
+        self,
+        *,
+        root: str,
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        self.root = root
+        self.data = DatasetFolder(root, loader=pil_loader, extensions=["jpg"])
+
+    def get_image_data(self, index: int) -> bytes:
+        return self.data[index][0]
+
+    def get_target(self, index: int) -> Optional[int]:
+        return 0
+
+    def __len__(self) -> int:
+        return len(self.data)
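Note: despite the bytes annotation, get_image_data above returns an already-decoded PIL image (DatasetFolder yields (sample, class_index) pairs and pil_loader opens the file), which is why ExtendedVisionDataset.__getitem__ is patched below to bypass ImageDataDecoder. A minimal usage sketch, assuming the images sit under at least one class subdirectory as torchvision's DatasetFolder requires; the /data/my_images path is a placeholder:

from dinov2.data.datasets import CustomData

# DatasetFolder expects a layout like /data/my_images/<some_class>/*.jpg,
# even though get_target() ignores the class and always returns 0.
dataset = CustomData(root="/data/my_images")
image, target = dataset[0]  # PIL.Image.Image and the constant target 0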
dinov2/data/datasets/extended.py
@@ -22,8 +22,9 @@ class ExtendedVisionDataset(VisionDataset):
 
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         try:
-            image_data = self.get_image_data(index)
-            image = ImageDataDecoder(image_data).decode()
+            # image_data = self.get_image_data(index)
+            # image = ImageDataDecoder(image_data).decode()
+            image = self.get_image_data(index)
         except Exception as e:
             raise RuntimeError(f"can not read image for sample {index}") from e
         target = self.get_target(index)
dinov2/data/loaders.py
@@ -10,7 +10,7 @@ from typing import Any, Callable, List, Optional, TypeVar
 import torch
 from torch.utils.data import Sampler
 
-from .datasets import ImageNet, ImageNet22k
+from .datasets import ImageNet, ImageNet22k, CustomData
 from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
 
 
@@ -58,6 +58,8 @@ def _parse_dataset_str(dataset_str: str):
             kwargs["split"] = ImageNet.Split[kwargs["split"]]
     elif name == "ImageNet22k":
         class_ = ImageNet22k
+    elif name == "CustomData":
+        class_ = CustomData
     else:
         raise ValueError(f'Unsupported dataset "{name}"')
 
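With this branch in place, the custom dataset can be selected through the same dataset-string syntax as ImageNet, e.g. train.dataset_path=CustomData:root=<path> in the training config. A minimal sketch using make_dataset from this module; the root path is a placeholder:

from dinov2.data import make_dataset

# _parse_dataset_str splits on ":" and "=", so this resolves to
# CustomData(root="/data/my_images") with no transforms applied.
dataset = make_dataset(dataset_str="CustomData:root=/data/my_images")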
dinov2/fsdp/__init__.py
@@ -63,11 +63,14 @@ def is_sharded_fsdp(x):
 
 
 def free_if_fsdp(x):
-    if is_sharded_fsdp(x):
-        handles = x._handles
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
-
+    # if is_sharded_fsdp(x):  # works for torch 2.0.0
+    # handles = x._handles
+    # true_list = [True for h in handles]
+    # _reshard(x, handles, true_list)
+    # make changes according to https://github.com/facebookresearch/dinov2/pull/281/files
+    if is_sharded_fsdp(x) and x._has_params:
+        handle = x._handle
+        _reshard(x, handle, True)
 
 def get_fsdp_modules(x):
     return FSDP.fsdp_modules(x)
dinov2/layers/block.py
@@ -172,7 +172,8 @@ def get_attn_bias_and_cat(x_list, branges=None):
         for b, x in zip(batch_sizes, x_list):
             for _ in range(b):
                 seqlens.append(x.shape[1])
-        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        # attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)  # works for newer versions of xformers
         attn_bias._batch_sizes = batch_sizes
         attn_bias_cache[all_shapes] = attn_bias
 
dinov2/run/submit.py
@@ -61,10 +61,16 @@ def get_args_parser(
         help="Partition where to submit",
     )
     parser.add_argument(
-        "--use-volta32",
-        action="store_true",
-        help="Request V100-32GB GPUs",
+        "--mem-per-gpu",
+        default="60G",
+        type=str,
+        help="Memory per GPU",
     )
+    # parser.add_argument(
+    #     "--use-volta32",
+    #     action="store_true",
+    #     help="Request V100-32GB GPUs",
+    # )
     parser.add_argument(
         "--comment",
         default="",
@@ -97,8 +103,8 @@ def submit_jobs(task_class, args, name: str):
     executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
 
     kwargs = {}
-    if args.use_volta32:
-        kwargs["slurm_constraint"] = "volta32gb"
+    # if args.use_volta32:
+    #     kwargs["slurm_constraint"] = "volta32gb"
     if args.comment:
         kwargs["slurm_comment"] = args.comment
     if args.exclude:
@@ -110,10 +116,11 @@ def submit_jobs(task_class, args, name: str):
         timeout_min=args.timeout,  # max is 60 * 72
         slurm_signal_delay_s=120,
         slurm_partition=args.partition,
+        mem_per_gpu=args.mem_per_gpu,
         **kwargs,
     )
     executor.update_parameters(name=name, **executor_params)
 
     print(args, executor_params)
     task = task_class(args)
     job = executor.submit(task)
dinov2/train/ssl_meta_arch.py
@@ -262,7 +262,8 @@ class SSLMetaArch(nn.Module):
            ]

            # 2: run
-            _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
+            # _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
+            _attn_bias, cat_inputs = fmha.attn_bias.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
            outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs))

            # 3a: local crops cls tokens
@@ -348,9 +349,15 @@ class SSLMetaArch(nn.Module):
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            # self.student.dino_head._streams = (
+            #     self.teacher.dino_head._streams
+            # ) = self.student.backbone._streams = self.teacher.backbone._streams
+            # make changes according to https://github.com/facebookresearch/dinov2/pull/281/files
+            for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"}:
+                stream = getattr(self.teacher.backbone, attr)
+                setattr(self.student.dino_head, attr, stream)
+                setattr(self.teacher.dino_head, attr, stream)
+                setattr(self.student.backbone, attr, stream)
             self.need_to_synchronize_fsdp_streams = False
 
     def update_teacher(self, m):
dinov2/utils/cluster.py
@@ -76,20 +76,20 @@ def get_slurm_executor_parameters(
 ) -> Dict[str, Any]:
     # create default parameters
     params = {
-        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+        # "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
         "gpus_per_node": num_gpus_per_node,
         "tasks_per_node": num_gpus_per_node,  # one task per GPU
-        "cpus_per_task": 10,
+        "cpus_per_gpu": 7,
         "nodes": nodes,
         "slurm_partition": get_slurm_partition(cluster_type),
     }
     # apply cluster-specific adjustments
     cluster_type = get_cluster_type(cluster_type)
     if cluster_type == ClusterType.AWS:
-        params["cpus_per_task"] = 12
+        params["cpus_per_gpu"] = 12
         del params["mem_gb"]
     elif cluster_type == ClusterType.RSC:
-        params["cpus_per_task"] = 12
+        params["cpus_per_gpu"] = 12
     # set additional parameters / apply overrides
     params.update(kwargs)
     return params