From 0cbb06fea64e27966fe73aebbd9f77853f60a144 Mon Sep 17 00:00:00 2001 From: Smriti Dahal <93288516+smritidahal653@users.noreply.github.com> Date: Wed, 17 Jul 2024 14:37:17 -0700 Subject: [PATCH] feat: [SKU-modularization] Get SKU Handler (#518) **Reason for Change**: - Reading the CLOUD_PROVIDER environment variable and initializing the respective SKUHandler to access the correct list of gpu skus **Requirements** - [ ] added unit tests and e2e tests (if applicable). **Issue Fixed**: **Notes for Reviewers**: This PR does not change anything in the current implementation to keep things clean. The next PR will get rid of sku_config from v1alpha1 and carry out necessary refactoring and test updates. --- pkg/sku/cloud_sku_handler.go | 15 +++++++++++++++ pkg/utils/common.go | 18 ++++++++++++++++++ pkg/utils/consts/consts.go | 2 ++ 3 files changed, 35 insertions(+) diff --git a/pkg/sku/cloud_sku_handler.go b/pkg/sku/cloud_sku_handler.go index afc555fcf..762da2ac1 100644 --- a/pkg/sku/cloud_sku_handler.go +++ b/pkg/sku/cloud_sku_handler.go @@ -3,6 +3,10 @@ package sku +import ( + "github.com/azure/kaito/pkg/utils/consts" +) + type CloudSKUHandler interface { GetSupportedSKUs() []string GetGPUConfigs() map[string]GPUConfig @@ -14,3 +18,14 @@ type GPUConfig struct { GPUMem int GPUModel string } + +func GetCloudSKUHandler(cloud string) CloudSKUHandler { + switch cloud { + case consts.AzureCloudName: + return NewAzureSKUHandler() + case consts.AWSCloudName: + return NewAwsSKUHandler() + default: + return nil + } +} diff --git a/pkg/utils/common.go b/pkg/utils/common.go index f29710c43..63772896d 100644 --- a/pkg/utils/common.go +++ b/pkg/utils/common.go @@ -10,7 +10,9 @@ import ( "gopkg.in/yaml.v2" "k8s.io/apimachinery/pkg/runtime" + "knative.dev/pkg/apis" + "github.com/azure/kaito/pkg/sku" "github.com/azure/kaito/pkg/utils/consts" ) @@ -101,3 +103,19 @@ func GetReleaseNamespace() (string, error) { } return "", fmt.Errorf("failed to determine release namespace from file %s and env var %s", namespaceFilePath, consts.DefaultReleaseNamespaceEnvVar) } + +func GetSKUHandler() (sku.CloudSKUHandler, error) { + // Get the cloud provider from the environment + provider := os.Getenv("CLOUD_PROVIDER") + + if provider == "" { + return nil, apis.ErrMissingField("CLOUD_PROVIDER environment variable must be set") + } + // Select the correct SKU handler based on the cloud provider + skuHandler := sku.GetCloudSKUHandler(provider) + if skuHandler == nil { + return nil, apis.ErrInvalidValue(fmt.Sprintf("Unsupported cloud provider %s", provider), "CLOUD_PROVIDER") + } + + return skuHandler, nil +} diff --git a/pkg/utils/consts/consts.go b/pkg/utils/consts/consts.go index d6ed62760..d15de22c5 100644 --- a/pkg/utils/consts/consts.go +++ b/pkg/utils/consts/consts.go @@ -8,4 +8,6 @@ const ( WorkspaceFinalizer = "workspace.finalizer.kaito.sh" DefaultReleaseNamespaceEnvVar = "RELEASE_NAMESPACE" FeatureFlagKarpenter = "Karpenter" + AzureCloudName = "azure" + AWSCloudName = "aws" )