Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request
Medium
Please provide a clear description of the problem you would like to solve.
I use Modulus DistributedManager with SLURM. Right now, DistributedManager sets local_rank from the SLURM_LOCALID environment variable, i.e. the process's local index on the node:
local_rank = int(os.environ.get("SLURM_LOCALID"))
That local_rank is then used to set the device:
manager._device = torch.device(
f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu"
)
Notably, this breaks if SLURM_LOCALID is greater than or equal to torch.cuda.device_count().
In my use case, however, I need to use the SBATCH --gpu-bind=map_gpu:0,1,2,3 flag on a node with 4 GPUs. With 4 processes per node and 4 GPUs per node, each process only sees one device, named cuda:0, though that name refers to a different physical GPU in each process. (A forum post explains why I need to use this flag.)
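To make the failure concrete, here is a minimal standalone reproduction of the mismatch (not Modulus code; the exact error message depends on the PyTorch version):

import os
import torch

# Under --gpu-bind=map_gpu:0,1,2,3 each rank is bound to one GPU, so
# torch.cuda.device_count() == 1 even though SLURM_LOCALID runs 0-3.
local_rank = int(os.environ.get("SLURM_LOCALID", "0"))
print(torch.cuda.device_count())  # prints 1 on every rank

# Constructing the device object succeeds, but using it fails for
# local_rank >= 1 with a CUDA "invalid device ordinal" error.
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)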
There may be other use cases where the number of local processes specified through SLURM may not equal the number of GPUs accessible (e.g. running FourCastNet with 4 GPUs and 1 process per GPU, but analyzing the output with more processes).
My request would be to add a flag to DistributedManager through which I could specify that the behavior sketched below is applied for SLURM as well.
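A rough sketch of what I mean, written as standalone code (the wrapping itself would ideally live inside DistributedManager behind the new flag; the variable names here are only illustrative):

import os
import torch

# Sketch: wrap the SLURM local rank onto the devices this process can
# actually see, instead of using it verbatim as a device index.
local_rank = int(os.environ.get("SLURM_LOCALID", "0"))
if torch.cuda.is_available():
    device_id = local_rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{device_id}")
else:
    device = torch.device("cpu")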
This ensures that torch.device is not called on a device that can't be accessed.
Describe any alternatives you have considered
Without a flag, DistributedManager.initialize() raises an error because torch.device is used to access a device that is not available. I could implement an equivalent for DistributedManager, or I could create a subclass of DistributedManager that overrides the initialize_slurm method. Let me know if that would be the preferred solution, and I can proceed with my fix on my end.
@ankurmahesh This is an interesting use case that I don't think we've encountered before. Am I understanding this correctly that you need this because you are using a TorchScript serialized model that has cuda:0 baked in as the device?
There are two solutions I can think of in this case:
1. Set the SLURM_LOCALID variable to 0 for all ranks before calling DistributedManager.initialize(), or
2. Add this feature so that the device ID is always derived from the local rank and the number of visible devices, along the lines sketched below:
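A sketch of that computation (not final code; manager here refers to the DistributedManager instance as in the snippet above):

# Wrap the local rank onto the visible devices so that an out-of-range
# SLURM_LOCALID still maps to a device this process can actually use.
device_id = manager.local_rank % torch.cuda.device_count()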
The second option would allow you to use the --gpu-bind argument in a SLURM environment, or you could also just set CUDA_VISIBLE_DEVICES=0 manually for all ranks.
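For the first option, the workaround would look something like this in the launch script, before the manager reads the SLURM environment (a sketch only; the import path is assumed from the Modulus docs):

import os

from modulus.distributed import DistributedManager

# Workaround sketch for option 1: every rank reports local ID 0, so the
# manager picks cuda:0, which is the one GPU each rank is bound to.
os.environ["SLURM_LOCALID"] = "0"

DistributedManager.initialize()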