97 lines
3.0 KiB
Python
97 lines
3.0 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the Apache License, Version 2.0
|
|
# found in the LICENSE file in the root directory of this source tree.
|
|
|
|
from enum import Enum
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
|
class ClusterType(Enum):
    """Identifiers for the compute clusters this module knows about.

    Each member's value is the lowercase string form of its name.
    """

    # Detected from a "-aws" kernel release suffix in _guess_cluster_type.
    AWS = "aws"
    # Fallback returned by _guess_cluster_type when nothing else matches.
    FAIR = "fair"
    # Detected from a hostname starting with "rsc" in _guess_cluster_type.
    RSC = "rsc"
|
|
|
|
|
|
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from the local machine's ``os.uname()`` data.

    Returns:
        ``ClusterType.AWS`` on Linux when the kernel release ends with "-aws",
        ``ClusterType.RSC`` on Linux when the hostname starts with "rsc",
        and ``ClusterType.FAIR`` in every other case (including non-Linux).
    """
    system_info = os.uname()

    # Only Linux hosts are inspected further; anything else is the fallback.
    if system_info.sysname != "Linux":
        return ClusterType.FAIR

    # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws".
    if system_info.release.endswith("-aws"):
        return ClusterType.AWS

    # RSC instances run standard kernels, but their hostnames start with "rsc".
    if system_info.nodename.startswith("rsc"):
        return ClusterType.RSC

    return ClusterType.FAIR
|
|
|
|
|
|
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return ``cluster_type`` unchanged, or guess one when it is ``None``.

    Args:
        cluster_type: Explicit cluster type; ``None`` triggers auto-detection
            through ``_guess_cluster_type``.

    Returns:
        The given cluster type, or the guessed one when none was given.
    """
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
|
|
|
|
|
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for a cluster.

    Args:
        cluster_type: Cluster to look up; resolved through
            ``get_cluster_type`` (so ``None`` means auto-detect).

    Returns:
        An absolute path built from "/" and the cluster's checkpoint
        directory name, or ``None`` if ``get_cluster_type`` yields ``None``.
    """
    resolved_type = get_cluster_type(cluster_type)
    if resolved_type is None:
        return None

    # Directory name, relative to the filesystem root, for each cluster.
    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    filesystem_root = Path("/")
    return filesystem_root / dirname_by_cluster[resolved_type]
|
|
|
|
|
|
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the current user's directory under the cluster checkpoint root.

    Args:
        cluster_type: Cluster to look up; forwarded to ``get_checkpoint_path``.

    Returns:
        ``<checkpoint root>/<$USER>``, or ``None`` when
        ``get_checkpoint_path`` returns ``None``.

    Raises:
        AssertionError: If the ``USER`` environment variable is not set.
    """
    base_dir = get_checkpoint_path(cluster_type)
    if base_dir is None:
        return None

    # The user name comes from the environment, not from the password database.
    user = os.environ.get("USER")
    assert user is not None
    return base_dir / user
|
|
|
|
|
|
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm partition name associated with a cluster.

    Args:
        cluster_type: Cluster to look up; resolved through
            ``get_cluster_type`` (so ``None`` means auto-detect).

    Returns:
        The partition name, or ``None`` if ``get_cluster_type`` yields ``None``.
    """
    resolved_type = get_cluster_type(cluster_type)
    if resolved_type is None:
        return None

    # Partition to submit to on each cluster.
    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved_type]
|
|
|
|
|
|
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, mem_per_gpu: str, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build the dictionary of Slurm executor parameters for a job.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: Number of GPUs per node; also used as the number
            of tasks per node (one task per GPU).
        mem_per_gpu: Memory request per GPU, stored under
            ``"slurm_mem_per_gpu"``.
        cluster_type: Cluster to target; resolved through
            ``get_cluster_type`` (so ``None`` means auto-detect).
        **kwargs: Additional parameters merged in last, so they override
            both the defaults and the cluster-specific adjustments.

    Returns:
        A dict of executor parameters (presumably consumed by a
        submitit-style Slurm executor; confirm against the caller).
    """
    # Resolve the cluster once so the partition lookup and the adjustments
    # below are guaranteed to agree.
    cluster_type = get_cluster_type(cluster_type)

    # Default parameters. Memory is requested per GPU; no "mem_gb" entry is
    # set, so there is nothing to remove for any cluster.
    params: Dict[str, Any] = {
        "slurm_mem_per_gpu": mem_per_gpu,
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }

    # Cluster-specific adjustments: AWS and RSC both use 12 CPUs per task.
    # The previous AWS-only `del params["mem_gb"]` was dropped because the
    # key is never present and the delete raised KeyError.
    if cluster_type in (ClusterType.AWS, ClusterType.RSC):
        params["cpus_per_task"] = 12

    # Set additional parameters / apply caller overrides.
    params.update(kwargs)
    return params
|