make sure training works with torchrun on the slurm cluster

pull/461/head
Adrian Wolny 2024-08-26 16:59:30 +02:00
parent 2039cf4dfd
commit cd75745578
5 changed files with 22 additions and 32 deletions

View File

@@ -85,20 +85,9 @@ class HistoPatchDataset(VisionDataset):
else:
self.file_list = self.file_list[train_count:]
# compute number of patches in the dataset
self.num_patches = 0
self.file_patch_count = {}
small_patch_file_count = 0
for file_path in tqdm(self.file_list):
patch_count = self._load_patches(file_path, return_count=True)
if patch_count != self.internal_patch_count:
small_patch_file_count += 1
assert small_patch_file_count <= 2, f"Multiple files with patch count smaller than {self.internal_patch_count} are not allowed!"
self.num_patches += patch_count
self.file_patch_count[file_path] = patch_count
# sort files by number of patches in descending order
self.file_list = sorted(self.file_list, key=lambda f: self.file_patch_count[f], reverse=True)
# calculate number of patches in the dataset
# there might be a single file with fewer patches than internal_patch_count, but we ignore it
self.num_patches = len(self.file_list) * self.internal_patch_count
# initialize file cache dictionary
self.file_cache = {}
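
The first hunk drops the per-file patch counting pass. Below is a minimal sketch of the new bookkeeping (the class and attribute names are illustrative, not the repo's code), assuming every patch file except possibly one holds exactly `internal_patch_count` patches:

```python
# Hypothetical sketch of the simplified dataset bookkeeping: instead of
# opening every file to count its patches, the total is taken to be
# len(file_list) * internal_patch_count; a single short file is tolerated
# and handled later by wrapping the patch index in __getitem__.
class PatchCountSketch:
    def __init__(self, file_list, internal_patch_count, cache_size=4):
        self.file_list = file_list
        self.internal_patch_count = internal_patch_count
        self.num_patches = len(file_list) * internal_patch_count
        self.cache_size = cache_size
        self.file_cache = {}  # file_path -> loaded patch array

    def __len__(self):
        return self.num_patches
```

This trades exact accounting for a much faster startup: the old loop called `_load_patches(file_path, return_count=True)` once per file before training could begin.
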
@@ -126,9 +115,11 @@ class HistoPatchDataset(VisionDataset):
file_path = self.file_list[file_idx]
if file_path in self.file_cache:
print("Cache hit")
# cache hit
patches = self.file_cache[file_path]
else:
print("Cache miss")
# cache miss: load patches from the file
if len(self.file_cache) >= self.cache_size:
# remove the first element from the cache
@@ -138,6 +129,8 @@ class HistoPatchDataset(VisionDataset):
patches = self._load_patches(file_path)
self.file_cache[file_path] = patches
# adjust patch_idx if necessary
patch_idx = patch_idx % patches.shape[0]
# convert patch to image
img = patches[patch_idx]
img = np.transpose(img, (1, 2, 0))
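
A compact sketch of the resulting cache path in `__getitem__` (the helper below is illustrative, with an assumed `loader` callable standing in for `_load_patches`):

```python
import numpy as np

def fetch_patch(file_cache, cache_size, file_path, patch_idx, loader):
    # illustrative stand-in for the __getitem__ body above
    if file_path in file_cache:
        patches = file_cache[file_path]              # cache hit
    else:
        if len(file_cache) >= cache_size:
            file_cache.pop(next(iter(file_cache)))   # evict the oldest entry
        patches = loader(file_path)                  # cache miss: load from disk
        file_cache[file_path] = patches
    # the one short file may hold fewer patches than internal_patch_count,
    # so the requested index is wrapped instead of raising an IndexError
    patch_idx = patch_idx % patches.shape[0]
    img = patches[patch_idx]
    return np.transpose(img, (1, 2, 0))              # CHW -> HWC
```
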

View File

@@ -50,16 +50,8 @@ class HistoInfiniteDistributedSampler(Sampler[T_co]):
self.file_patch_indices = {}
for file_id in self.file_ids:
start_idx = file_id * self.dataset.internal_patch_count
file_path = self.dataset.file_list[file_id]
patch_count = self.dataset.file_patch_count[file_path]
end_idx = start_idx + patch_count
end_idx = start_idx + self.dataset.internal_patch_count
indices = list(range(start_idx, end_idx))
if patch_count < self.dataset.internal_patch_count:
# if the file has fewer patches than internal_patch_count, duplicate patches from the same file
padding_size = self.dataset.internal_patch_count - patch_count
padding_indices = np.random.choice(range(start_idx, end_idx), padding_size, replace=True)
indices.extend(padding_indices)
self.file_patch_indices[file_id] = indices
def __iter__(self) -> Iterator[T_co]:
@@ -81,8 +73,6 @@ class HistoInfiniteDistributedSampler(Sampler[T_co]):
for file_id in file_ids:
indices.extend(self.file_patch_indices[file_id])
assert len(indices) == self.num_samples
# shuffle indices if necessary
if self.shuffle:
g = torch.Generator().manual_seed(self.seed)
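
With every file assumed to contribute exactly `internal_patch_count` indices, the sampler's index bookkeeping collapses to contiguous ranges. A simplified, self-contained sketch of that idea (not the sampler's actual `__iter__`, which is distributed and infinite):

```python
import torch

def build_file_patch_indices(file_ids, internal_patch_count):
    # every file contributes a contiguous block of indices; padding short
    # files with duplicated indices is no longer needed
    file_patch_indices = {}
    for file_id in file_ids:
        start_idx = file_id * internal_patch_count
        end_idx = start_idx + internal_patch_count
        file_patch_indices[file_id] = list(range(start_idx, end_idx))
    return file_patch_indices

def flatten_and_shuffle(file_patch_indices, seed=0):
    # gather all per-file blocks and shuffle them with a seeded generator,
    # mirroring the torch.Generator().manual_seed(...) pattern above
    indices = []
    for file_id in file_patch_indices:
        indices.extend(file_patch_indices[file_id])
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(indices), generator=g)
    return [indices[i] for i in perm.tolist()]
```
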

View File

@@ -157,8 +157,9 @@ class _TorchDistributedEnvironment:
self.local_rank = -1
self.local_world_size = -1
if _is_slurm_job_process():
return self._set_from_slurm_env()
# prevent messing with the environment!
#if _is_slurm_job_process():
# return self._set_from_slurm_env()
env_vars = _collect_env_vars()
if not env_vars:
@@ -250,6 +251,8 @@ def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, all
raise RuntimeError("Distributed mode has already been enabled")
torch_env = _TorchDistributedEnvironment()
torch_env.export(overwrite=overwrite)
# print OS env for debugging
print(os.environ)
if set_cuda_current_device:
torch.cuda.set_device(torch_env.local_rank)
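
The rationale, as the first hunk's comment ("prevent messing with the environment!") suggests: under `torchrun`, every worker already receives the rendezvous variables in its environment, so deriving them from `SLURM_*` variables would overwrite what `torchrun` exported. A hedged sketch of what `_collect_env_vars` is presumably picking up (the helper name below is made up):

```python
import os

# Variables exported by torchrun (torch.distributed.run) for each worker.
TORCHRUN_VARS = (
    "MASTER_ADDR",
    "MASTER_PORT",
    "RANK",
    "WORLD_SIZE",
    "LOCAL_RANK",
    "LOCAL_WORLD_SIZE",
)

def collect_torchrun_env():
    # hypothetical helper, not the repo's _collect_env_vars
    return {name: os.environ[name] for name in TORCHRUN_VARS if name in os.environ}
```

The `print(os.environ)` added in the second hunk is a temporary debugging aid to confirm those variables actually reach every rank.
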

View File

@@ -7,6 +7,7 @@ import argparse
import logging
import math
import os
import time
from functools import partial
from fvcore.common.checkpoint import PeriodicCheckpointer
@@ -243,8 +244,9 @@ def do_train(cfg, model, resume=False):
last_layer_lr = last_layer_lr_schedule[iteration]
apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
# measure the time of single iteration
s0 = time.perf_counter()
# compute losses
optimizer.zero_grad(set_to_none=True)
loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
@@ -262,6 +264,8 @@
for v in model.student.values():
v.clip_grad_norm_(cfg.optim.clip_grad)
optimizer.step()
# log iter time
logger.info(f"iter time: {time.perf_counter() - s0:.2f}")
# perform teacher EMA update
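
The timing change brackets the forward/backward pass and optimizer step with `time.perf_counter()` and logs the wall-clock duration of each iteration. A minimal standalone sketch of the same pattern (names are illustrative):

```python
import logging
import time

logger = logging.getLogger(__name__)

def timed_step(step_fn):
    # wall-clock a single training step and log its duration, mirroring the
    # s0 = time.perf_counter() / logger.info(...) pair added above
    s0 = time.perf_counter()
    result = step_fn()
    logger.info(f"iter time: {time.perf_counter() - s0:.2f}")
    return result
```

Since CUDA kernels run asynchronously, a wall-clock number like this is a coarse throughput check; for strict per-iteration GPU timing one would call `torch.cuda.synchronize()` before reading the clock.
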

View File

@@ -76,9 +76,9 @@ def get_slurm_executor_parameters(
) -> Dict[str, Any]:
# create default parameters
params = {
"mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
"gpus_per_node": num_gpus_per_node,
"tasks_per_node": num_gpus_per_node, # one task per GPU
#"mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
#"gpus_per_node": num_gpus_per_node,
#"tasks_per_node": num_gpus_per_node, # one task per GPU
"cpus_per_task": 10,
"nodes": nodes,
"slurm_partition": get_slurm_partition(cluster_type),
@@ -87,7 +87,7 @@
cluster_type = get_cluster_type(cluster_type)
if cluster_type == ClusterType.AWS:
params["cpus_per_task"] = 12
del params["mem_gb"]
#del params["mem_gb"]
elif cluster_type == ClusterType.RSC:
params["cpus_per_task"] = 12
# set additional parameters / apply overrides
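
Commenting out `mem_gb`, `gpus_per_node` and `tasks_per_node` suggests the GPU/task layout is now requested outside this helper (presumably by whatever wraps the `torchrun` launch on the cluster), leaving the executor with only CPU, node and partition settings. A hypothetical reconstruction of the trimmed parameter dict:

```python
from typing import Any, Dict

def trimmed_executor_params(nodes: int, partition: str, on_aws: bool = False) -> Dict[str, Any]:
    # hypothetical reconstruction of the parameter dict after the change:
    # GPU count, tasks-per-node and the all-memory request are no longer set here
    params: Dict[str, Any] = {
        "cpus_per_task": 12 if on_aws else 10,
        "nodes": nodes,
        "slurm_partition": partition,
    }
    return params
```
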