diff --git a/torchx/specs/named_resources_aws.py b/torchx/specs/named_resources_aws.py index 27de01b6b..c3e45b6dc 100644 --- a/torchx/specs/named_resources_aws.py +++ b/torchx/specs/named_resources_aws.py @@ -33,6 +33,8 @@ from torchx.specs.api import Resource +EFA_DEVICE = "vpc.amazonaws.com/efa" + # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits # so we have to account for mem tax when registering these resources for AWS # otherwise the job will be stuck in the jobqueue forever @@ -63,20 +65,32 @@ def aws_p3_16xlarge() -> Resource: def aws_p3dn_24xlarge() -> Resource: return Resource( - cpu=96, gpu=8, memMB=768 * GiB, capabilities={K8S_ITYPE: "p3dn.24xlarge"} + cpu=96, + gpu=8, + memMB=768 * GiB, + capabilities={K8S_ITYPE: "p3dn.24xlarge"}, + devices={EFA_DEVICE: 1}, ) def aws_p4d_24xlarge() -> Resource: return Resource( - cpu=96, gpu=8, memMB=1152 * GiB, capabilities={K8S_ITYPE: "p4d.24xlarge"} + cpu=96, + gpu=8, + memMB=1152 * GiB, + capabilities={K8S_ITYPE: "p4d.24xlarge"}, + devices={EFA_DEVICE: 4}, ) def aws_p4de_24xlarge() -> Resource: # p4de has same cpu, gpu, memMB as p4d but gpu memory is 2x (32GB vs 64GB per GPU) return Resource( - cpu=96, gpu=8, memMB=1152 * GiB, capabilities={K8S_ITYPE: "p4de.24xlarge"} + cpu=96, + gpu=8, + memMB=1152 * GiB, + capabilities={K8S_ITYPE: "p4de.24xlarge"}, + devices={EFA_DEVICE: 4}, ) diff --git a/torchx/specs/test/named_resources_aws_test.py b/torchx/specs/test/named_resources_aws_test.py index 113bee42f..d7d3b9755 100644 --- a/torchx/specs/test/named_resources_aws_test.py +++ b/torchx/specs/test/named_resources_aws_test.py @@ -33,6 +33,7 @@ aws_t3_medium, aws_trn1_2xl, aws_trn1_32xl, + EFA_DEVICE, GiB, K8S_ITYPE, NAMED_RESOURCES, @@ -60,6 +61,7 @@ def test_aws_p3(self) -> None: self.assertEqual(96, p3dn_24.cpu) self.assertEqual(p3_16.gpu, p3dn_24.gpu) self.assertEqual(768 * GiB, p3dn_24.memMB) + self.assertEqual({EFA_DEVICE: 1}, p3dn_24.devices) def test_aws_p4(self) -> None: p4d = aws_p4d_24xlarge() @@ -68,10 +70,12 @@ def test_aws_p4(self) -> None: self.assertEqual(96, p4d.cpu) self.assertEqual(8, p4d.gpu) self.assertEqual(1152 * GiB, p4d.memMB) + self.assertEqual({EFA_DEVICE: 4}, p4d.devices) self.assertEqual(p4de.cpu, p4d.cpu) self.assertEqual(p4de.gpu, p4d.gpu) self.assertEqual(p4de.memMB, p4d.memMB) + self.assertEqual({EFA_DEVICE: 4}, p4de.devices) def test_aws_g4dn(self) -> None: g4d = aws_g4dn_xlarge()