make sure training works with torchrun on the slurm cluster
parent 2039cf4dfd
commit cd75745578
@@ -85,20 +85,9 @@ class HistoPatchDataset(VisionDataset):
         else:
             self.file_list = self.file_list[train_count:]
 
-        # compute number of patches in the dataset
-        self.num_patches = 0
-        self.file_patch_count = {}
-        small_patch_file_count = 0
-        for file_path in tqdm(self.file_list):
-            patch_count = self._load_patches(file_path, return_count=True)
-            if patch_count != self.internal_patch_count:
-                small_patch_file_count += 1
-                assert small_patch_file_count <= 2, f"Multiple files with patch count smaller than {self.internal_patch_count} are not allowed!"
-            self.num_patches += patch_count
-            self.file_patch_count[file_path] = patch_count
-
-        # sort files by number of patches in descending order
-        self.file_list = sorted(self.file_list, key=lambda f: self.file_patch_count[f], reverse=True)
+        # calculate number of patches in the dataset
+        # there might be a single file with fewer patches than internal_patch_count, but we ignore it
+        self.num_patches = len(self.file_list) * self.internal_patch_count
 
         # initialize file cache dictionary
         self.file_cache = {}
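
Note: with the per-file counting loop removed, the dataset length is simply files x internal_patch_count. A minimal sketch of that index arithmetic (hypothetical helper names, not the repository's code); a single short file is tolerated because __getitem__ wraps patch_idx with a modulo in the hunks below:

def dataset_len(file_list, internal_patch_count):
    # every file is assumed to contribute a fixed-size block of indices
    return len(file_list) * internal_patch_count

def split_index(index, internal_patch_count):
    # map a flat dataset index to (file block, position inside the block)
    return index // internal_patch_count, index % internal_patch_count
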
@@ -126,9 +115,11 @@ class HistoPatchDataset(VisionDataset):
 
         file_path = self.file_list[file_idx]
         if file_path in self.file_cache:
+            print("Cache hit")
             # cache hit
             patches = self.file_cache[file_path]
         else:
+            print("Cache miss")
             # cache miss: load patches from the file
             if len(self.file_cache) >= self.cache_size:
                 # remove the first element from the cache
@@ -138,6 +129,8 @@ class HistoPatchDataset(VisionDataset):
             patches = self._load_patches(file_path)
             self.file_cache[file_path] = patches
 
+        # adjust patch_idx if necessary
+        patch_idx = patch_idx % patches.shape[0]
         # convert patch to image
         img = patches[patch_idx]
         img = np.transpose(img, (1, 2, 0))
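
Note: the cache-miss branch evicts "the first element" once the cache is full, and the new modulo keeps patch_idx valid for a file that actually holds fewer patches than assumed. A sketch of both ideas together (assumes Python 3.7+ insertion-ordered dicts and a hypothetical load_patches helper, not the class method itself):

import numpy as np

def get_patch(file_path, patch_idx, cache, cache_size, load_patches):
    if file_path in cache:
        patches = cache[file_path]
    else:
        if len(cache) >= cache_size:
            # dicts preserve insertion order, so this evicts the oldest cached file
            del cache[next(iter(cache))]
        patches = load_patches(file_path)
        cache[file_path] = patches
    # wrap the index in case this file has fewer patches than internal_patch_count
    patch_idx = patch_idx % patches.shape[0]
    img = patches[patch_idx]             # (C, H, W)
    return np.transpose(img, (1, 2, 0))  # -> (H, W, C)
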
@@ -50,16 +50,8 @@ class HistoInfiniteDistributedSampler(Sampler[T_co]):
         self.file_patch_indices = {}
         for file_id in self.file_ids:
             start_idx = file_id * self.dataset.internal_patch_count
-            file_path = self.dataset.file_list[file_id]
-            patch_count = self.dataset.file_patch_count[file_path]
-            end_idx = start_idx + patch_count
+            end_idx = start_idx + self.dataset.internal_patch_count
             indices = list(range(start_idx, end_idx))
-            if patch_count < self.dataset.internal_patch_count:
-                # if the file has fewer patches than internal_patch_count, duplicate patches from the same file
-                padding_size = self.dataset.internal_patch_count - patch_count
-                padding_indices = np.random.choice(range(start_idx, end_idx), padding_size, replace=True)
-                indices.extend(padding_indices)
-
             self.file_patch_indices[file_id] = indices
 
     def __iter__(self) -> Iterator[T_co]:
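
Note: the sampler no longer consults dataset.file_patch_count; every file_id now owns a contiguous block of exactly internal_patch_count flat indices. A sketch of the resulting mapping (standalone function for illustration, not the class method itself):

def build_file_patch_indices(file_ids, internal_patch_count):
    file_patch_indices = {}
    for file_id in file_ids:
        start_idx = file_id * internal_patch_count
        # fixed-size block, matching the dataset's new num_patches computation
        file_patch_indices[file_id] = list(range(start_idx, start_idx + internal_patch_count))
    return file_patch_indices
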
@@ -81,8 +73,6 @@ class HistoInfiniteDistributedSampler(Sampler[T_co]):
         for file_id in file_ids:
             indices.extend(self.file_patch_indices[file_id])
 
-        assert len(indices) == self.num_samples
-
         # shuffle indices if necessary
         if self.shuffle:
             g = torch.Generator().manual_seed(self.seed)
@@ -157,8 +157,9 @@ class _TorchDistributedEnvironment:
         self.local_rank = -1
         self.local_world_size = -1
 
-        if _is_slurm_job_process():
-            return self._set_from_slurm_env()
+        # prevent messing with the environment!
+        #if _is_slurm_job_process():
+        #    return self._set_from_slurm_env()
 
         env_vars = _collect_env_vars()
         if not env_vars:
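
Note: with the SLURM branch disabled, rank information falls back to the variables that torchrun itself exports to every worker (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE). A sketch of reading them, not the repository's _collect_env_vars implementation:

import os

TORCHRUN_VARS = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE")

def read_torchrun_env():
    # torchrun sets all of these for each spawned worker process
    missing = [name for name in TORCHRUN_VARS if name not in os.environ]
    if missing:
        raise RuntimeError(f"Not launched via torchrun, missing env vars: {missing}")
    # values come back as strings; cast to int where rank/size numbers are needed
    return {name: os.environ[name] for name in TORCHRUN_VARS}
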
@@ -250,6 +251,8 @@ def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, all
         raise RuntimeError("Distributed mode has already been enabled")
     torch_env = _TorchDistributedEnvironment()
     torch_env.export(overwrite=overwrite)
+    # print OS env for debugging
+    print(os.environ)
 
     if set_cuda_current_device:
         torch.cuda.set_device(torch_env.local_rank)
@@ -7,6 +7,7 @@ import argparse
 import logging
 import math
 import os
+import time
 from functools import partial
 
 from fvcore.common.checkpoint import PeriodicCheckpointer
@@ -243,8 +244,9 @@ def do_train(cfg, model, resume=False):
         last_layer_lr = last_layer_lr_schedule[iteration]
         apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
 
+        # measure the time of single iteration
+        s0 = time.perf_counter()
         # compute losses
-
         optimizer.zero_grad(set_to_none=True)
         loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
 
@@ -262,6 +264,8 @@ def do_train(cfg, model, resume=False):
                 for v in model.student.values():
                     v.clip_grad_norm_(cfg.optim.clip_grad)
             optimizer.step()
+            # log iter time
+            logger.info(f"iter time: {time.perf_counter() - s0:.2f}")
 
         # perform teacher EMA update
 
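
Note: time.perf_counter() measures host wall-clock time only; CUDA kernels run asynchronously, so the logged iter time reflects GPU work only up to the next synchronizing call. A sketch of a stricter per-step measurement (assumes a CUDA device is available; illustrative helper, not the training-loop code):

import time
import torch

def timed_step(step_fn):
    torch.cuda.synchronize()  # do not attribute previously queued GPU work to this step
    s0 = time.perf_counter()
    result = step_fn()
    torch.cuda.synchronize()  # wait for this step's kernels before reading the clock
    return result, time.perf_counter() - s0
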
@@ -76,9 +76,9 @@ def get_slurm_executor_parameters(
 ) -> Dict[str, Any]:
     # create default parameters
     params = {
-        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
-        "gpus_per_node": num_gpus_per_node,
-        "tasks_per_node": num_gpus_per_node,  # one task per GPU
+        #"mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+        #"gpus_per_node": num_gpus_per_node,
+        #"tasks_per_node": num_gpus_per_node, # one task per GPU
         "cpus_per_task": 10,
         "nodes": nodes,
         "slurm_partition": get_slurm_partition(cluster_type),
@@ -87,7 +87,7 @@ def get_slurm_executor_parameters(
     cluster_type = get_cluster_type(cluster_type)
     if cluster_type == ClusterType.AWS:
         params["cpus_per_task"] = 12
-        del params["mem_gb"]
+        #del params["mem_gb"]
     elif cluster_type == ClusterType.RSC:
         params["cpus_per_task"] = 12
     # set additional parameters / apply overrides
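
Note: these parameters are handed to a submitit AutoExecutor elsewhere; with gpus_per_node and tasks_per_node commented out, SLURM is no longer asked for one task per GPU, which is consistent with letting a single torchrun per node spawn the per-GPU workers itself. A sketch of the hand-off (hypothetical log folder, not the repository's submission code):

import submitit

def make_executor(params):
    executor = submitit.AutoExecutor(folder="logs/%j")  # hypothetical folder; %j expands to the job id
    executor.update_parameters(**params)                # keys map onto sbatch options
    return executor
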