Skip to content

Commit

Permalink
feat: [SKU-modularization] Get SKU Handler (kaito-project#518)
Browse files Browse the repository at this point in the history
**Reason for Change**:
- Reading the CLOUD_PROVIDER environment variable and initializing the
respective SKUHandler to access the correct list of gpu skus

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:
This PR does not change anything in the current implementation to keep
things clean. The next PR will get rid of sku_config from v1alpha1 and
carry out necessary refactoring and test updates.
  • Loading branch information
smritidahal653 authored Jul 17, 2024
1 parent cbfaea8 commit 0cbb06f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
15 changes: 15 additions & 0 deletions pkg/sku/cloud_sku_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

package sku

import (
"github.com/azure/kaito/pkg/utils/consts"
)

type CloudSKUHandler interface {
GetSupportedSKUs() []string
GetGPUConfigs() map[string]GPUConfig
Expand All @@ -14,3 +18,14 @@ type GPUConfig struct {
GPUMem int
GPUModel string
}

func GetCloudSKUHandler(cloud string) CloudSKUHandler {
switch cloud {
case consts.AzureCloudName:
return NewAzureSKUHandler()
case consts.AWSCloudName:
return NewAwsSKUHandler()
default:
return nil
}
}
18 changes: 18 additions & 0 deletions pkg/utils/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (

"gopkg.in/yaml.v2"
"k8s.io/apimachinery/pkg/runtime"
"knative.dev/pkg/apis"

"github.com/azure/kaito/pkg/sku"
"github.com/azure/kaito/pkg/utils/consts"
)

Expand Down Expand Up @@ -101,3 +103,19 @@ func GetReleaseNamespace() (string, error) {
}
return "", fmt.Errorf("failed to determine release namespace from file %s and env var %s", namespaceFilePath, consts.DefaultReleaseNamespaceEnvVar)
}

func GetSKUHandler() (sku.CloudSKUHandler, error) {
// Get the cloud provider from the environment
provider := os.Getenv("CLOUD_PROVIDER")

if provider == "" {
return nil, apis.ErrMissingField("CLOUD_PROVIDER environment variable must be set")
}
// Select the correct SKU handler based on the cloud provider
skuHandler := sku.GetCloudSKUHandler(provider)
if skuHandler == nil {
return nil, apis.ErrInvalidValue(fmt.Sprintf("Unsupported cloud provider %s", provider), "CLOUD_PROVIDER")
}

return skuHandler, nil
}
2 changes: 2 additions & 0 deletions pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ const (
WorkspaceFinalizer = "workspace.finalizer.kaito.sh"
DefaultReleaseNamespaceEnvVar = "RELEASE_NAMESPACE"
FeatureFlagKarpenter = "Karpenter"
AzureCloudName = "azure"
AWSCloudName = "aws"
)

0 comments on commit 0cbb06f

Please sign in to comment.