-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[sarc-386] Acquire node_to_gpu and gpu_billing from cached slurm conf…
… files. (#138) * [sarc-386] Acquire node_to_gpu and gpu_billing from cached slurm conf files. * Display an error message and halt script for cluster mila which is not yet correctly supported. --------- Co-authored-by: Bruno Carrez <[email protected]>
- Loading branch information
1 parent
85bd044
commit 87ad1b1
Showing
17 changed files
with
821 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from dataclasses import dataclass | ||
from typing import Dict, List | ||
|
||
from hostlist import expand_hostlist | ||
from simple_parsing import field | ||
|
||
from sarc.cache import CachePolicy, with_cache | ||
from sarc.client.gpumetrics import _gpu_billing_collection | ||
from sarc.config import config | ||
from sarc.jobs.node_gpu_mapping import _node_gpu_mapping_collection | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class AcquireSlurmConfig: | ||
cluster_name: str = field(alias=["-c"]) | ||
day: str = field(alias=["-d"]) | ||
|
||
def execute(self) -> int: | ||
if self.cluster_name == "mila": | ||
logger.error("Cluster `mila` not yet supported.") | ||
return -1 | ||
|
||
parser = SlurmConfigParser(self.cluster_name, self.day) | ||
slurm_conf = parser.get_slurm_config() | ||
_gpu_billing_collection().save_gpu_billing( | ||
self.cluster_name, self.day, slurm_conf.gpu_to_billing | ||
) | ||
_node_gpu_mapping_collection().save_node_gpu_mapping( | ||
self.cluster_name, self.day, slurm_conf.node_to_gpu | ||
) | ||
return 0 | ||
|
||
|
||
class SlurmConfigParser: | ||
def __init__(self, cluster_name: str, day: str): | ||
self.cluster_name = cluster_name | ||
self.day = day | ||
|
||
def get_slurm_config(self) -> SlurmConfig: | ||
return with_cache( | ||
self._get_slurm_conf, | ||
subdirectory="slurm_conf", | ||
key=self._cache_key, | ||
formatter=self, | ||
)(cache_policy=CachePolicy.always) | ||
|
||
def _get_slurm_conf(self): | ||
raise RuntimeError( | ||
f"Please add cluster slurm.conf file into cache, at location: " | ||
f"{config().cache}/slurm_conf/{self._cache_key()}" | ||
) | ||
|
||
def _cache_key(self): | ||
return f"slurm.{self.cluster_name}.{self.day}.conf" | ||
|
||
def load(self, file) -> SlurmConfig: | ||
""" | ||
Parse cached slurm conf file and return a SlurmConfig object | ||
containing node_to_gpu and gpu_to_billing. | ||
""" | ||
|
||
partitions: List[Partition] = [] | ||
node_to_gpu = {} | ||
|
||
# Parse lines: extract partitions and node_to_gpu | ||
for line_number, line in enumerate(file): | ||
line = line.strip() | ||
if line.startswith("PartitionName="): | ||
partitions.append( | ||
Partition( | ||
line_number=line_number + 1, | ||
line=line, | ||
info=dict( | ||
option.split("=", maxsplit=1) for option in line.split() | ||
), | ||
) | ||
) | ||
elif line.startswith("NodeName="): | ||
nodes_config = dict( | ||
option.split("=", maxsplit=1) for option in line.split() | ||
) | ||
gpu_type = nodes_config.get("Gres") | ||
if gpu_type: | ||
node_to_gpu.update( | ||
{ | ||
node_name: gpu_type | ||
for node_name in expand_hostlist(nodes_config["NodeName"]) | ||
} | ||
) | ||
|
||
# Parse partitions: extract gpu_to_billing | ||
gpu_to_billing = self._parse_gpu_to_billing(partitions, node_to_gpu) | ||
|
||
# Return parsed data | ||
return SlurmConfig(node_to_gpu=node_to_gpu, gpu_to_billing=gpu_to_billing) | ||
|
||
def _parse_gpu_to_billing( | ||
self, partitions: List[Partition], node_to_gpu: Dict[str, str] | ||
) -> Dict[str, float]: | ||
|
||
# Mapping of GPU to partition billing. | ||
# ALlow to check that inferred billing for a GPU is the same across partitions. | ||
# If not, an error will be raised with additional info about involved partitions. | ||
gpu_to_partition_billing: Dict[str, PartitionGPUBilling] = {} | ||
|
||
# Collection for all GPUs found in partition nodes. | ||
# We will later iterate on this collection to resolve any GPU without billing. | ||
all_partition_node_gpus = set() | ||
|
||
for partition in partitions: | ||
# Get all GPUs in partition nodes and all partition GPU billings. | ||
local_gres, local_gpu_to_billing = ( | ||
partition.get_gpus_and_partition_billings(node_to_gpu) | ||
) | ||
|
||
# Merge local GPUs into global partition node GPUs. | ||
all_partition_node_gpus.update(local_gres) | ||
|
||
# Merge local GPU billings into global GPU billings | ||
for gpu_type, value in local_gpu_to_billing.items(): | ||
new_billing = PartitionGPUBilling( | ||
gpu_type=gpu_type, value=value, partition=partition | ||
) | ||
if gpu_type not in gpu_to_partition_billing: | ||
# New GPU found, add it | ||
gpu_to_partition_billing[gpu_type] = new_billing | ||
elif gpu_to_partition_billing[gpu_type].value != value: | ||
# GPU already found, with a different billing. Problem. | ||
raise InconsistentGPUBillingError( | ||
gpu_type, gpu_to_partition_billing[gpu_type], new_billing | ||
) | ||
|
||
# Generate GPU->billing mapping | ||
gpu_to_billing = { | ||
gpu_type: billing.value | ||
for gpu_type, billing in gpu_to_partition_billing.items() | ||
} | ||
|
||
# Resolve GPUs without billing | ||
for gpu_desc in all_partition_node_gpus: | ||
if gpu_desc not in gpu_to_billing: | ||
if gpu_desc.startswith("gpu:") and gpu_desc.count(":") == 2: | ||
# GPU resource with format `gpu:<real-gpu-type>:<gpu-count>` | ||
_, gpu_type, gpu_count = gpu_desc.split(":") | ||
if gpu_type in gpu_to_billing: | ||
billing = gpu_to_billing[gpu_type] * int(gpu_count) | ||
gpu_to_billing[gpu_desc] = billing | ||
logger.info(f"Inferred billing for {gpu_desc} := {billing}") | ||
else: | ||
logger.warning( | ||
f"Cannot find GPU billing for GPU type {gpu_type} in GPU resource {gpu_desc}" | ||
) | ||
else: | ||
logger.warning(f"Cannot infer billing for GPU: {gpu_desc}") | ||
|
||
# We can finally return GPU->billing mapping. | ||
return gpu_to_billing | ||
|
||
|
||
@dataclass | ||
class SlurmConfig: | ||
"""Parsed data from slurm config file""" | ||
|
||
node_to_gpu: Dict[str, str] | ||
gpu_to_billing: Dict[str, float] | ||
|
||
|
||
@dataclass | ||
class Partition: | ||
"""Partition entry in slurm config file""" | ||
|
||
line_number: int | ||
line: str | ||
info: Dict[str, str] | ||
|
||
def get_gpus_and_partition_billings(self, node_to_gpu: Dict[str, str]): | ||
""" | ||
Parse and return: | ||
- partition node GPUs | ||
- partition GPU billings inferred from field `TRESBillingWeights` | ||
""" | ||
|
||
# Get partition node GPUs | ||
local_gres = self._get_node_gpus(node_to_gpu) | ||
|
||
# Get GPU weights from TRESBillingWeights | ||
tres_billing_weights = dict( | ||
option.split("=", maxsplit=1) | ||
for option in self.info.get("TRESBillingWeights", "").split(",") | ||
if option | ||
) | ||
gpu_weights = { | ||
key: value | ||
for key, value in tres_billing_weights.items() | ||
if key.startswith("GRES/gpu") | ||
} | ||
|
||
# Parse local GPU billings | ||
local_gpu_to_billing = {} | ||
for key, value in gpu_weights.items(): | ||
value = float(value) | ||
if key == "GRES/gpu": | ||
if len(gpu_weights) == 1: | ||
# We only have `GRES/gpu=<value>` | ||
# Let's map value to each GPU found in partition nodes | ||
local_gpu_to_billing.update( | ||
{gpu_name: value for gpu_name in local_gres} | ||
) | ||
else: | ||
# We have both `GRES/gpu=<value>` and at least one `GRES/gpu:a_gpu=a_value`. | ||
# Ambiguous case, cannot map `GRES/gpu=<value>` to a specific GPU. | ||
logger.debug( | ||
f"[line {self.line_number}] " | ||
f"Ignored ambiguous GPU billing (cannot match to a specific GPU): `{key}={value}` " | ||
f"| partition: {self.info['PartitionName']} " | ||
# f"| nodes: {partition.info['Nodes']}, " | ||
f"| nodes GPUs: {', '.join(local_gres)} " | ||
f"| TRESBillingWeights: {self.info['TRESBillingWeights']}" | ||
) | ||
else: | ||
# We have `GRES/gpu:a_gpu=a_value`. | ||
# We can parse. | ||
_, gpu_name = key.split(":", maxsplit=1) | ||
local_gpu_to_billing[gpu_name] = value | ||
|
||
return local_gres, local_gpu_to_billing | ||
|
||
def _get_node_gpus(self, node_to_gpu: Dict[str, str]) -> List[str]: | ||
"""Return all GPUs found in partition nodes""" | ||
return sorted( | ||
{ | ||
gres | ||
for node_name in expand_hostlist(self.info["Nodes"]) | ||
for gres in node_to_gpu.get(node_name, "").split(",") | ||
if gres | ||
} | ||
) | ||
|
||
|
||
@dataclass | ||
class PartitionGPUBilling: | ||
"""Represents a GPU billing found in a specific partition entry.""" | ||
|
||
partition: Partition | ||
gpu_type: str | ||
value: float | ||
|
||
|
||
class InconsistentGPUBillingError(Exception): | ||
def __init__( | ||
self, | ||
gpu_type: str, | ||
prev_billing: PartitionGPUBilling, | ||
new_billing: PartitionGPUBilling, | ||
): | ||
super().__init__( | ||
f"\n" | ||
f"GPU billing differs.\n" | ||
f"GPU name: {gpu_type}\n" | ||
f"Previous value: {prev_billing.value}\n" | ||
f"From line: {prev_billing.partition.line_number}\n" | ||
f"{prev_billing.partition.line}\n" | ||
f"\n" | ||
f"New value: {new_billing.value}\n" | ||
f"From line: {new_billing.partition.line_number}\n" | ||
f"{new_billing.partition.line}\n" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.