diff --git a/README.md b/README.md index 32b98b10f..fdfe3131c 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,39 @@ -# Kubernetes AI Toolchain Operator(KAITO) +# Kubernetes AI Toolchain Operator (Kaito) [![Go Report Card](https://goreportcard.com/badge/github.com/Azure/kaito)](https://goreportcard.com/report/github.com/Azure/kaito) ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/Azure/kaito) -KAITO has been designed to simplify the workflow of launching AI inference services against popular large open sourced AI models, -such as Falcon or Llama, in a Kubernetes cluster. +Kaito is an operator that automates the AI/ML inference model deployment in a Kubernetes cluster. +The target models are popular large open-source inference models such as [falcon](https://huggingface.co/tiiuae) and [llama 2](https://github.com/facebookresearch/llama). +Kaito has the following key differentiations compared to most mainstream model deployment methodologies built on top of virtual machine infrastructures. +- Manage large model files using container images. An HTTP server is provided to perform inference calls using the model library. +- Avoid tuning deployment parameters to fit GPU hardware by providing preset configurations. +- Auto-provision GPU nodes based on model requirements. +- Host large model images in the public Microsoft Container Registry (MCR) if the license allows. + +Using Kaito, the workflow of onboarding large AI inference models in Kubernetes is greatly simplified. + + +## Architecture + +Kaito follows the classic Kubernetes Custom Resource Definition (CRD)/controller design pattern. Users manage a `workspace` custom resource, which describes the GPU requirements and the inference specification. Kaito controllers automate the deployment by reconciling the `workspace` custom resource. +
+ +<img src="docs/img/arch.png" alt="Kaito architecture">
+ +The above figure presents the Kaito architecture overview. Its major components are: +- **Workspace controller**: It reconciles the `workspace` custom resource, creates `machine` (explained below) custom resources to trigger node auto-provisioning, and creates the inference workload (`deployment` or `statefulset`) based on the model preset configurations. +- **Node provisioner controller**: The controller's name is *gpu-provisioner* in the [Kaito helm chart](charts/kaito/gpu-provisioner). It uses the `machine` CRD originating from [Karpenter](https://github.com/aws/karpenter-core) to interact with the workspace controller. It integrates with Azure Kubernetes Service (AKS) APIs to add new GPU nodes to the AKS cluster. +Note that the *gpu-provisioner* is not an open-source component. It can be replaced by other controllers if they support the Karpenter-core APIs. + + +--- ## Installation -The following guidence assumes **Azure Kubernetes Service(AKS)** is used to host the Kubernetes cluster . +The following guidance assumes **Azure Kubernetes Service (AKS)** is used to host the Kubernetes cluster. -### Enable Workload Identity and OIDC Issuer features -The `gpu-povisioner` component requires the [workload identity](https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview?tabs=dotnet) feature to acquire the token to access the AKS managed cluster with proper permissions. +#### Enable Workload Identity and OIDC Issuer features +The *gpu-provisioner* controller requires the [workload identity](https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview?tabs=dotnet) feature to acquire an access token to the AKS cluster. 
```bash export RESOURCE_GROUP="myResourceGroup" @@ -18,8 +41,8 @@ export MY_CLUSTER="myCluster" az aks update -g $RESOURCE_GROUP -n $MY_CLUSTER --enable-oidc-issuer --enable-workload-identity --enable-managed-identity ``` -### Create an identity and assign permissions -The identity `kaitoprovisioner` is created for the `gpu-povisioner` controller. It is assigned Contributor role for the managed cluster resource to allow changing `$MY_CLUSTER` (e.g., provisioning new nodes in it). +#### Create an identity and assign permissions +The identity `kaitoprovisioner` is created for the *gpu-provisioner* controller. It is assigned the Contributor role for the managed cluster resource to allow changing `$MY_CLUSTER` (e.g., provisioning new nodes in it). ```bash export SUBSCRIPTION="mySubscription" az identity create --name kaitoprovisioner -g $RESOURCE_GROUP @@ -29,7 +52,7 @@ az role assignment create --assignee $IDENTITY_PRINCIPAL_ID --scope /subscriptio ``` -### Install helm charts +#### Install helm charts Two charts will be installed in `$MY_CLUSTER`: `gpu-provisioner` chart and `workspace` chart. ```bash helm install workspace ./charts/kaito/workspace @@ -49,25 +72,27 @@ helm install gpu-provisioner ./charts/kaito/gpu-provisioner ``` -### Create federated credential -This allows `gpu-provisioner` controller to use `kaitoprovisioner` identity via an access token. +#### Create the federated credential +Create a federated identity credential between the managed identity `kaitoprovisioner` and the service account used by the *gpu-provisioner* controller. 
```bash export AKS_OIDC_ISSUER=$(az aks show -n $MY_CLUSTER -g $RESOURCE_GROUP --subscription $SUBSCRIPTION --query "oidcIssuerProfile.issuerUrl" | tr -d '"') az identity federated-credential create --name kaito-federatedcredential --identity-name kaitoprovisioner -g $RESOURCE_GROUP --issuer $AKS_OIDC_ISSUER --subject system:serviceaccount:"gpu-provisioner:gpu-provisioner" --audience api://AzureADTokenExchange --subscription $SUBSCRIPTION ``` -Note that before doing this step, the `gpu-provisioner` controller pod will constantly fail with the following message in the log: +Then the *gpu-provisioner* can access the managed cluster using a trusted token with the same permissions as the `kaitoprovisioner` identity. +Note that before finishing this step, the *gpu-provisioner* controller pod will constantly fail with the following message in the log: ``` panic: Configure azure client fails. Please ensure federatedcredential has been created for identity XXXX. ``` The pod will reach the running state once the federated credential is created. -### Clean up +#### Clean up ```bash helm uninstall gpu-provisioner helm uninstall workspace ``` +--- ## Quick start After installing Kaito, one can try the following commands to start a falcon-7b inference service. @@ -88,14 +113,14 @@ inference: $ kubectl apply -f examples/kaito_workspace_falcon_7b.yaml ``` -The workspace status can be tracked by running the following command. +The workspace status can be tracked by running the following command. When the WORKSPACEREADY column becomes `True`, the model has been deployed successfully. ``` $ kubectl get workspace workspace-falcon-7b NAME INSTANCE RESOURCEREADY INFERENCEREADY WORKSPACEREADY AGE workspace-falcon-7b Standard_NC12s_v3 True True True 10m ``` -Once the workspace is ready, one can find the inference service's cluster ip and use a temporal `curl` pod to test the service endpoint in cluster. 
+Next, one can find the inference service's cluster IP and use a temporary `curl` pod to test the service endpoint in the cluster. ``` $ kubectl get svc workspace-falcon-7b NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE @@ -105,7 +130,7 @@ $ kubectl run -it --rm --restart=Never curl --image=curlimages/curl sh ~ $ curl -X POST http://<CLUSTERIP>/chat -H "accept: application/json" -H "Content-Type: application/json" -d "{\"prompt\":\"YOUR QUESTION HERE\"}" ``` - +--- ## Contributing [Read more](docs/contributing/readme.md) diff --git a/api/v1alpha1/sku_config.go b/api/v1alpha1/sku_config.go new file mode 100644 index 000000000..68c9f0dc5 --- /dev/null +++ b/api/v1alpha1/sku_config.go @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package v1alpha1 + +import "strings" + +type GPUConfig struct { + SKU string + SupportedOS []string + GPUDriver string + GPUCount int + GPUMem int +} + +type PresetRequirements struct { + MinGPUCount int + MinMemoryPerGPU int // in GB + MinTotalMemory int // in GB +} + +var PresetRequirementsMap = map[string]PresetRequirements{ + "falcon-7b": {MinGPUCount: 1, MinMemoryPerGPU: 0, MinTotalMemory: 15}, + "falcon-7b-instruct": {MinGPUCount: 1, MinMemoryPerGPU: 0, MinTotalMemory: 15}, + "falcon-40b": {MinGPUCount: 2, MinMemoryPerGPU: 0, MinTotalMemory: 90}, + "falcon-40b-instruct": {MinGPUCount: 2, MinMemoryPerGPU: 0, MinTotalMemory: 90}, + + "llama-2-7b": {MinGPUCount: 1, MinMemoryPerGPU: 14, MinTotalMemory: 14}, + "llama-2-13b": {MinGPUCount: 2, MinMemoryPerGPU: 15, MinTotalMemory: 30}, + "llama-2-70b": {MinGPUCount: 8, MinMemoryPerGPU: 19, MinTotalMemory: 152}, + + "llama-2-7b-chat": {MinGPUCount: 1, MinMemoryPerGPU: 14, MinTotalMemory: 14}, + "llama-2-13b-chat": {MinGPUCount: 2, MinMemoryPerGPU: 15, MinTotalMemory: 30}, + "llama-2-70b-chat": {MinGPUCount: 8, MinMemoryPerGPU: 19, MinTotalMemory: 152}, +} + +// Helper function to check if a preset is valid +func isValidPreset(preset string) bool { + _, 
exists := PresetRequirementsMap[preset] + return exists +} + +func getSupportedSKUs() string { + skus := make([]string, 0, len(SupportedGPUConfigs)) + for sku := range SupportedGPUConfigs { + skus = append(skus, sku) + } + return strings.Join(skus, ", ") +} + +var SupportedGPUConfigs = map[string]GPUConfig{ + "standard_nc6": {SKU: "standard_nc6", GPUCount: 1, GPUMem: 12, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, + "standard_nc12": {SKU: "standard_nc12", GPUCount: 2, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, + "standard_nc24": {SKU: "standard_nc24", GPUCount: 4, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, + "standard_nc24r": {SKU: "standard_nc24r", GPUCount: 4, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, + "standard_nv6": {SKU: "standard_nv6", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv12": {SKU: "standard_nv12", GPUCount: 2, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv24": {SKU: "standard_nv24", GPUCount: 4, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv12s_v3": {SKU: "standard_nv12s_v3", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv24s_v3": {SKU: "standard_nv24s_v3", GPUCount: 2, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv48s_v3": {SKU: "standard_nv48s_v3", GPUCount: 4, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + // "standard_nv24r": {SKU: "standard_nv24r", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nd6s": {SKU: "standard_nd6s", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nd12s": {SKU: "standard_nd12s", 
GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nd24s": {SKU: "standard_nd24s", GPUCount: 4, GPUMem: 96, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nd24rs": {SKU: "standard_nd24rs", GPUCount: 4, GPUMem: 96, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc6s_v2": {SKU: "standard_nc6s_v2", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc12s_v2": {SKU: "standard_nc12s_v2", GPUCount: 2, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc24s_v2": {SKU: "standard_nc24s_v2", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc24rs_v2": {SKU: "standard_nc24rs_v2", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc6s_v3": {SKU: "standard_nc6s_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc12s_v3": {SKU: "standard_nc12s_v3", GPUCount: 2, GPUMem: 32, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc24s_v3": {SKU: "standard_nc24s_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc24rs_v3": {SKU: "standard_nc24rs_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_nd40s_v3": {SKU: "standard_nd40s_v3", GPUCount: x, GPUMem: x, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nd40rs_v2": {SKU: "standard_nd40rs_v2", GPUCount: 8, GPUMem: 256, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc4as_t4_v3": {SKU: "standard_nc4as_t4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, 
GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc8as_t4_v3": {SKU: "standard_nc8as_t4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc16as_t4_v3": {SKU: "standard_nc16as_t4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc64as_t4_v3": {SKU: "standard_nc64as_t4_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nd96asr_v4": {SKU: "standard_nd96asr_v4", GPUCount: 8, GPUMem: 320, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_nd112asr_a100_v4": {SKU: "standard_nd112asr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_nd120asr_a100_v4": {SKU: "standard_nd120asr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nd96amsr_a100_v4": {SKU: "standard_nd96amsr_a100_v4", GPUCount: 8, GPUMem: 640, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_nd112amsr_a100_v4": {SKU: "standard_nd112amsr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_nd120amsr_a100_v4": {SKU: "standard_nd120amsr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc24ads_a100_v4": {SKU: "standard_nc24ads_a100_v4", GPUCount: 1, GPUMem: 80, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc48ads_a100_v4": {SKU: "standard_nc48ads_a100_v4", GPUCount: 2, GPUMem: 160, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + "standard_nc96ads_a100_v4": {SKU: "standard_nc96ads_a100_v4", GPUCount: 4, GPUMem: 320, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_ncads_a100_v4": {SKU: 
"standard_ncads_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + /*GPU Mem based on A10-24 Spec - TODO: Need to confirm GPU Mem*/ + // "standard_nc8ads_a10_v4": {SKU: "standard_nc8ads_a10_v4", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + // "standard_nc16ads_a10_v4": {SKU: "standard_nc16ads_a10_v4", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + // "standard_nc32ads_a10_v4": {SKU: "standard_nc32ads_a10_v4", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + /* SKUs with GPU Partition are treated as 1 GPU - https://learn.microsoft.com/en-us/azure/virtual-machines/nva10v5-series*/ + "standard_nv6ads_a10_v5": {SKU: "standard_nv6ads_a10_v5", GPUCount: 1, GPUMem: 4, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv12ads_a10_v5": {SKU: "standard_nv12ads_a10_v5", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv18ads_a10_v5": {SKU: "standard_nv18ads_a10_v5", GPUCount: 1, GPUMem: 12, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv36ads_a10_v5": {SKU: "standard_nv36ads_a10_v5", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv36adms_a10_v5": {SKU: "standard_nv36adms_a10_v5", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + "standard_nv72ads_a10_v5": {SKU: "standard_nv72ads_a10_v5", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, + // "standard_nd96ams_v4": {SKU: "standard_nd96ams_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, + // "standard_nd96ams_a100_v4": {SKU: "standard_nd96ams_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: 
"Nvidia525CudaDriver"}, +} diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 61e091b72..4496d3311 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "strings" admissionregistrationv1 "k8s.io/api/admissionregistration/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -11,6 +12,11 @@ import ( "knative.dev/pkg/apis" ) +const ( + N_SERIES_PREFIX = "standard_n" + D_SERIES_PREFIX = "standard_d" +) + func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType { return []admissionregistrationv1.OperationType{ admissionregistrationv1.Create, @@ -32,6 +38,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { - The preset name needs to be supported enum. */ errs = errs.Also( + w.Resource.validateCreate(w.Inference).ViaField("resource"), w.Inference.validateCreate().ViaField("inference"), ) } else { @@ -45,6 +52,50 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { return errs } +func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { + // inference.Preset may be nil when a custom pod template is used instead of a preset. + presetName := "" + if inference.Preset != nil { + presetName = strings.ToLower(string(inference.Preset.Name)) + } + instanceType := strings.ToLower(string(r.InstanceType)) + + // Check if the instance type exists in our SKU map + if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists { + // Validate GPU count for given SKU + if presetReq, ok := PresetRequirementsMap[presetName]; ok { + machineCount := *r.Count + totalNumGPUs := machineCount * skuConfig.GPUCount + totalGPUMem := machineCount * skuConfig.GPUMem * skuConfig.GPUCount + + // Separate the checks for specific error messages + if totalNumGPUs < presetReq.MinGPUCount { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient number of GPUs: Instance type %s provides %d, but preset %s requires at least %d", instanceType, totalNumGPUs, presetName, presetReq.MinGPUCount), 
"instanceType")) + } + if skuConfig.GPUMem < presetReq.MinMemoryPerGPU { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient GPU memory: Instance type %s provides %dGB per GPU, but preset %s requires at least %dGB per GPU", instanceType, skuConfig.GPUMem, presetName, presetReq.MinMemoryPerGPU), "instanceType")) + } + if totalGPUMem < presetReq.MinTotalMemory { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient total GPU memory: Instance type %s has a total of %dGB, but preset %s requires at least %dGB", instanceType, totalGPUMem, presetName, presetReq.MinTotalMemory), "instanceType")) + } + } else { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName")) + } + } else { + // Fall back to prefix matching for other instance types + if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType")) + } + } + + // Validate labelSelector + if _, err := metav1.LabelSelectorAsMap(r.LabelSelector); err != nil { + errs = errs.Also(apis.ErrInvalidValue(err.Error(), "labelSelector")) + } + + return errs +} + func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. 
if r.Count != nil && old.Count != nil && *r.Count != *old.Count { @@ -66,6 +113,13 @@ func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) } func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) { + // i.Preset may be nil when only a pod template is provided. + if i.Preset != nil { + presetName := strings.ToLower(string(i.Preset.Name)) + // Validate preset name + if !isValidPreset(presetName) { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName")) + } + } if i.Preset != nil && i.Template != nil { errs = errs.Also(apis.ErrGeneric("preset and template cannot be set at the same time")) } diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go new file mode 100644 index 000000000..39eda44a7 --- /dev/null +++ b/api/v1alpha1/workspace_validation_test.go @@ -0,0 +1,478 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package v1alpha1 + +import ( + "reflect" + "sort" + "strings" + "testing" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func pointerToInt(i int) *int { + return &i +} + +func TestResourceSpecValidateCreate(t *testing.T) { + tests := []struct { + name string + resourceSpec *ResourceSpec + inferenceSpec *InferenceSpec + errContent string // Content expected error to include, if any + expectErrs bool + }{ + { + name: "Insufficient total GPU memory", + resourceSpec: &ResourceSpec{ + InstanceType: "standard_nc6", + Count: pointerToInt(1), + }, + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("falcon-7b"), + }, + }, + }, + errContent: "Insufficient total GPU memory", + expectErrs: true, + }, + + { + name: "Insufficient number of GPUs", + resourceSpec: &ResourceSpec{ + InstanceType: "standard_nc24ads_a100_v4", + Count: pointerToInt(1), + }, + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("llama-2-13b-chat"), + }, + }, + }, + errContent: "Insufficient 
number of GPUs", + expectErrs: true, + }, + + { + name: "Invalid SKU", + resourceSpec: &ResourceSpec{ + InstanceType: "standard_invalid_sku", + Count: pointerToInt(1), + }, + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("llama-2-70b"), + }, + }, + }, + errContent: "Unsupported instance", + expectErrs: true, + }, + + { + name: "Invalid Preset", + resourceSpec: &ResourceSpec{ + InstanceType: "standard_nv12s_v3", + Count: pointerToInt(1), + }, + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("invalid-preset"), + }, + }, + }, + errContent: "Unsupported preset", + expectErrs: true, + }, + + { + name: "N-Prefix SKU", + resourceSpec: &ResourceSpec{ + InstanceType: "standard_nsku", + Count: pointerToInt(1), + }, + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("llama-2-7b"), + }, + }, + }, + errContent: "", + expectErrs: false, + }, + + { + name: "D-Prefix SKU", + resourceSpec: &ResourceSpec{ + InstanceType: "standard_dsku", + Count: pointerToInt(1), + }, + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("llama-2-7b"), + }, + }, + }, + errContent: "", + expectErrs: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errs := tc.resourceSpec.validateCreate(*tc.inferenceSpec) + hasErrs := errs != nil + if hasErrs != tc.expectErrs { + t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs) + } + + // If there is an error and errContent is not empty, check that the error contains the expected content. 
+ if hasErrs && tc.errContent != "" { + errMsg := errs.Error() + if !strings.Contains(errMsg, tc.errContent) { + t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent) + } + } + }) + } +} + +func TestResourceSpecValidateUpdate(t *testing.T) { + + tests := []struct { + name string + newResource *ResourceSpec + oldResource *ResourceSpec + errContent string // Content expected error to include, if any + expectErrs bool + }{ + { + name: "Immutable Count", + newResource: &ResourceSpec{ + Count: pointerToInt(10), + }, + oldResource: &ResourceSpec{ + Count: pointerToInt(5), + }, + errContent: "field is immutable", + expectErrs: true, + }, + { + name: "Immutable InstanceType", + newResource: &ResourceSpec{ + InstanceType: "new_type", + }, + oldResource: &ResourceSpec{ + InstanceType: "old_type", + }, + errContent: "field is immutable", + expectErrs: true, + }, + { + name: "Immutable LabelSelector", + newResource: &ResourceSpec{ + LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key1": "value1"}}, + }, + oldResource: &ResourceSpec{ + LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key2": "value2"}}, + }, + errContent: "field is immutable", + expectErrs: true, + }, + { + name: "Valid Update", + newResource: &ResourceSpec{ + Count: pointerToInt(5), + InstanceType: "same_type", + LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key": "value"}}, + }, + oldResource: &ResourceSpec{ + Count: pointerToInt(5), + InstanceType: "same_type", + LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key": "value"}}, + }, + errContent: "", + expectErrs: false, + }, + } + + // Run the tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errs := tc.newResource.validateUpdate(tc.oldResource) + hasErrs := errs != nil + if hasErrs != tc.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs) + } + + // If there is an 
error and errContent is not empty, check that the error contains the expected content. + if hasErrs && tc.errContent != "" { + errMsg := errs.Error() + if !strings.Contains(errMsg, tc.errContent) { + t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent) + } + } + }) + } +} + +func TestInferenceSpecValidateCreate(t *testing.T) { + tests := []struct { + name string + inferenceSpec *InferenceSpec + errContent string // Content expected error to include, if any + expectErrs bool + }{ + { + name: "Invalid Preset Name", + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("Invalid-Preset-Name"), + }, + }, + }, + errContent: "Unsupported preset name", + expectErrs: true, + }, + { + name: "Preset and Template Set", + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("falcon-7b"), + }, + }, + Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set + }, + errContent: "preset and template cannot be set at the same time", + expectErrs: true, + }, + { + name: "Private Access Without Image", + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("llama-2-7b"), + AccessMode: "private", + }, + PresetOptions: PresetOptions{}, + }, + }, + errContent: "When AccessMode is private, an image must be provided", + expectErrs: true, + }, + { + name: "Valid Preset", + inferenceSpec: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("falcon-7b"), + AccessMode: "public", + }, + }, + }, + errContent: "", + expectErrs: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errs := tc.inferenceSpec.validateCreate() + hasErrs := errs != nil + if hasErrs != tc.expectErrs { + t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs) + } + + // If there is an error and errContent is not empty, check that the 
error contains the expected content. + if hasErrs && tc.errContent != "" { + errMsg := errs.Error() + if !strings.Contains(errMsg, tc.errContent) { + t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent) + } + } + }) + } +} + +func TestInferenceSpecValidateUpdate(t *testing.T) { + tests := []struct { + name string + newInference *InferenceSpec + oldInference *InferenceSpec + errContent string // Content expected error to include, if any + expectErrs bool + }{ + { + name: "Preset Immutable", + newInference: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("new-preset"), + }, + }, + }, + oldInference: &InferenceSpec{ + Preset: &PresetSpec{ + PresetMeta: PresetMeta{ + Name: ModelName("old-preset"), + }, + }, + }, + errContent: "field is immutable", + expectErrs: true, + }, + { + name: "Template Unset", + newInference: &InferenceSpec{ + Template: nil, + }, + oldInference: &InferenceSpec{ + Template: &v1.PodTemplateSpec{}, + }, + errContent: "field cannot be unset/set if it was set/unset", + expectErrs: true, + }, + { + name: "Template Set", + newInference: &InferenceSpec{ + Template: &v1.PodTemplateSpec{}, + }, + oldInference: &InferenceSpec{ + Template: nil, + }, + errContent: "field cannot be unset/set if it was set/unset", + expectErrs: true, + }, + { + name: "Valid Update", + newInference: &InferenceSpec{ + Template: &v1.PodTemplateSpec{}, + }, + oldInference: &InferenceSpec{ + Template: &v1.PodTemplateSpec{}, + }, + errContent: "", + expectErrs: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errs := tc.newInference.validateUpdate(tc.oldInference) + hasErrs := errs != nil + if hasErrs != tc.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs) + } + + // If there is an error and errContent is not empty, check that the error contains the expected content. 
+ if hasErrs && tc.errContent != "" { + errMsg := errs.Error() + if !strings.Contains(errMsg, tc.errContent) { + t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent) + } + } + }) + } +} + +func TestGetSupportedSKUs(t *testing.T) { + tests := []struct { + name string + gpuConfigs map[string]GPUConfig + expectedResult []string // changed to a slice for deterministic ordering + }{ + { + name: "no SKUs supported", + gpuConfigs: map[string]GPUConfig{}, + expectedResult: []string{""}, + }, + { + name: "one SKU supported", + gpuConfigs: map[string]GPUConfig{ + "standard_nc6": {SKU: "standard_nc6"}, + }, + expectedResult: []string{"standard_nc6"}, + }, + { + name: "multiple SKUs supported", + gpuConfigs: map[string]GPUConfig{ + "standard_nc6": {SKU: "standard_nc6"}, + "standard_nc12": {SKU: "standard_nc12"}, + }, + expectedResult: []string{"standard_nc6", "standard_nc12"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + SupportedGPUConfigs = tc.gpuConfigs + + resultSlice := strings.Split(getSupportedSKUs(), ", ") + sort.Strings(resultSlice) + + // Sort the expectedResult for comparison + expectedResultSlice := tc.expectedResult + sort.Strings(expectedResultSlice) + + if !reflect.DeepEqual(resultSlice, expectedResultSlice) { + t.Errorf("getSupportedSKUs() = %v, want %v", resultSlice, expectedResultSlice) + } + }) + } +} + +func TestIsValidPreset(t *testing.T) { + tests := []struct { + name string + preset string + expectValid bool + }{ + { + name: "valid preset", + preset: "falcon-7b", + expectValid: true, + }, + { + name: "invalid preset", + preset: "nonexistent-preset", + expectValid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if valid := isValidPreset(test.preset); valid != test.expectValid { + t.Errorf("isValidPreset(%s) = %v, want %v", test.preset, valid, test.expectValid) + } + }) + } +} diff --git a/api/v1alpha1/zz_generated.deepcopy.go 
b/api/v1alpha1/zz_generated.deepcopy.go index bfdc825b4..ef60d7362 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -14,6 +14,26 @@ import ( runtime "k8s.io/apimachinery/pkg/runtime" ) +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *GPUConfig) DeepCopyInto(out *GPUConfig) { + *out = *in + if in.SupportedOS != nil { + in, out := &in.SupportedOS, &out.SupportedOS + *out = make([]string, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUConfig. +func (in *GPUConfig) DeepCopy() *GPUConfig { + if in == nil { + return nil + } + out := new(GPUConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *InferenceSpec) DeepCopyInto(out *InferenceSpec) { *out = *in @@ -74,6 +94,21 @@ func (in *PresetOptions) DeepCopy() *PresetOptions { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PresetRequirements) DeepCopyInto(out *PresetRequirements) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PresetRequirements. +func (in *PresetRequirements) DeepCopy() *PresetRequirements { + if in == nil { + return nil + } + out := new(PresetRequirements) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *PresetSpec) DeepCopyInto(out *PresetSpec) { *out = *in diff --git a/docs/img/arch.png b/docs/img/arch.png new file mode 100644 index 000000000..94c399faf Binary files /dev/null and b/docs/img/arch.png differ
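
The resource validation this PR adds to `workspace_validation.go` boils down to simple arithmetic: total GPUs = machine count × GPUs per machine, total memory = machine count × GPUs per machine × memory per GPU, each compared against the preset's minimums. The following is a standalone sketch of that check; the type and function names (`gpuConfig`, `presetReq`, `fits`) are illustrative stand-ins, not part of this PR, and the numbers are taken from `SupportedGPUConfigs` and `PresetRequirementsMap` above.

```go
package main

import "fmt"

// gpuConfig and presetReq are simplified stand-ins for GPUConfig and
// PresetRequirements defined in sku_config.go.
type gpuConfig struct {
	gpuCount int // GPUs per machine
	gpuMem   int // memory per GPU, in GB
}

type presetReq struct {
	minGPUCount     int
	minMemoryPerGPU int // GB
	minTotalMemory  int // GB
}

// fits reports whether `count` machines of the given SKU satisfy the preset,
// mirroring the arithmetic in ResourceSpec.validateCreate.
func fits(count int, sku gpuConfig, req presetReq) bool {
	totalGPUs := count * sku.gpuCount
	totalMem := count * sku.gpuCount * sku.gpuMem
	return totalGPUs >= req.minGPUCount &&
		sku.gpuMem >= req.minMemoryPerGPU &&
		totalMem >= req.minTotalMemory
}

func main() {
	// standard_nc6: 1 GPU x 12GB; falcon-7b needs >= 15GB total.
	nc6 := gpuConfig{gpuCount: 1, gpuMem: 12}
	falcon7b := presetReq{minGPUCount: 1, minMemoryPerGPU: 0, minTotalMemory: 15}
	fmt.Println(fits(1, nc6, falcon7b)) // false: 12GB total < 15GB

	// standard_nc24ads_a100_v4: 1 GPU x 80GB; llama-2-13b-chat needs >= 2 GPUs.
	a100 := gpuConfig{gpuCount: 1, gpuMem: 80}
	llama13bChat := presetReq{minGPUCount: 2, minMemoryPerGPU: 15, minTotalMemory: 30}
	fmt.Println(fits(1, a100, llama13bChat)) // false: 1 GPU < 2
	fmt.Println(fits(2, a100, llama13bChat)) // true: 2 GPUs, 160GB total
}
```

The first two cases correspond to the "Insufficient total GPU memory" and "Insufficient number of GPUs" rows in `TestResourceSpecValidateCreate`; the PR attaches each failed comparison to its own `apis.FieldError` so users get a targeted message rather than a single generic rejection.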