make training on custom data work!

pull/488/head
liyubrt 2024-10-22 19:51:52 -07:00
parent e1277af2ba
commit 400983cf83
9 changed files with 86 additions and 23 deletions

View File

@ -5,3 +5,4 @@
from .image_net import ImageNet
from .image_net_22k import ImageNet22k
from .custom_data import CustomData

View File

@ -0,0 +1,41 @@
from enum import Enum
import logging
import os
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
from torchvision.datasets import DatasetFolder
from .extended import ExtendedVisionDataset
logger = logging.getLogger("dinov2")
def pil_loader(p):
return Image.open(p).convert("RGB")
class CustomData(ExtendedVisionDataset):
def __init__(
self,
*,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms, transform, target_transform)
self.root = root
self.data = DatasetFolder(root, loader=pil_loader, extensions=["jpg"])
def get_image_data(self, index: int) -> bytes:
return self.data[index][0]
def get_target(self, index: int) -> Optional[int]:
return 0
def __len__(self) -> int:
return len(self.data)

View File

@ -22,8 +22,9 @@ class ExtendedVisionDataset(VisionDataset):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
try:
image_data = self.get_image_data(index)
image = ImageDataDecoder(image_data).decode()
# image_data = self.get_image_data(index)
# image = ImageDataDecoder(image_data).decode()
image = self.get_image_data(index)
except Exception as e:
raise RuntimeError(f"can not read image for sample {index}") from e
target = self.get_target(index)

View File

@ -10,7 +10,7 @@ from typing import Any, Callable, List, Optional, TypeVar
import torch
from torch.utils.data import Sampler
from .datasets import ImageNet, ImageNet22k
from .datasets import ImageNet, ImageNet22k, CustomData
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
@ -58,6 +58,8 @@ def _parse_dataset_str(dataset_str: str):
kwargs["split"] = ImageNet.Split[kwargs["split"]]
elif name == "ImageNet22k":
class_ = ImageNet22k
elif name == "CustomData":
class_ = CustomData
else:
raise ValueError(f'Unsupported dataset "{name}"')

View File

@ -63,11 +63,14 @@ def is_sharded_fsdp(x):
def free_if_fsdp(x):
if is_sharded_fsdp(x):
handles = x._handles
true_list = [True for h in handles]
_reshard(x, handles, true_list)
# if is_sharded_fsdp(x): # works for torch 2.0.0
# handles = x._handles
# true_list = [True for h in handles]
# _reshard(x, handles, true_list)
# make changes according to https://github.com/facebookresearch/dinov2/pull/281/files
if is_sharded_fsdp(x) and x._has_params:
handle = x._handle
_reshard(x, handle, True)
def get_fsdp_modules(x):
return FSDP.fsdp_modules(x)

View File

@ -172,7 +172,8 @@ def get_attn_bias_and_cat(x_list, branges=None):
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
# attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens) # works for newer verions of xformers
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias

View File

@ -61,10 +61,16 @@ def get_args_parser(
help="Partition where to submit",
)
parser.add_argument(
"--use-volta32",
action="store_true",
help="Request V100-32GB GPUs",
"--mem-per-gpu",
default="60G",
type=str,
help="Memory per GPU",
)
# parser.add_argument(
# "--use-volta32",
# action="store_true",
# help="Request V100-32GB GPUs",
# )
parser.add_argument(
"--comment",
default="",
@ -97,8 +103,8 @@ def submit_jobs(task_class, args, name: str):
executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
kwargs = {}
if args.use_volta32:
kwargs["slurm_constraint"] = "volta32gb"
# if args.use_volta32:
# kwargs["slurm_constraint"] = "volta32gb"
if args.comment:
kwargs["slurm_comment"] = args.comment
if args.exclude:
@ -110,10 +116,11 @@ def submit_jobs(task_class, args, name: str):
timeout_min=args.timeout, # max is 60 * 72
slurm_signal_delay_s=120,
slurm_partition=args.partition,
mem_per_gpu=args.mem_per_gpu,
**kwargs,
)
executor.update_parameters(name=name, **executor_params)
print(args, executor_params)
task = task_class(args)
job = executor.submit(task)

View File

@ -262,7 +262,8 @@ class SSLMetaArch(nn.Module):
]
# 2: run
_attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
# _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
_attn_bias, cat_inputs = fmha.attn_bias.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs))
# 3a: local crops cls tokens
@ -348,9 +349,15 @@ class SSLMetaArch(nn.Module):
def fsdp_synchronize_streams(self):
if self.need_to_synchronize_fsdp_streams:
torch.cuda.synchronize()
self.student.dino_head._streams = (
self.teacher.dino_head._streams
) = self.student.backbone._streams = self.teacher.backbone._streams
# self.student.dino_head._streams = (
# self.teacher.dino_head._streams
# ) = self.student.backbone._streams = self.teacher.backbone._streams
# make changes according to https://github.com/facebookresearch/dinov2/pull/281/files
for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"}:
stream = getattr(self.teacher.backbone, attr)
setattr(self.student.dino_head, attr, stream)
setattr(self.teacher.dino_head, attr, stream)
setattr(self.student.backbone, attr, stream)
self.need_to_synchronize_fsdp_streams = False
def update_teacher(self, m):

View File

@ -76,20 +76,20 @@ def get_slurm_executor_parameters(
) -> Dict[str, Any]:
# create default parameters
params = {
"mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
# "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
"gpus_per_node": num_gpus_per_node,
"tasks_per_node": num_gpus_per_node, # one task per GPU
"cpus_per_task": 10,
"cpus_per_gpu": 7,
"nodes": nodes,
"slurm_partition": get_slurm_partition(cluster_type),
}
# apply cluster-specific adjustments
cluster_type = get_cluster_type(cluster_type)
if cluster_type == ClusterType.AWS:
params["cpus_per_task"] = 12
params["cpus_per_gpu"] = 12
del params["mem_gb"]
elif cluster_type == ClusterType.RSC:
params["cpus_per_task"] = 12
params["cpus_per_gpu"] = 12
# set additional parameters / apply overrides
params.update(kwargs)
return params