diff --git a/README.md b/README.md
index 32b98b10f..fdfe3131c 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,39 @@
-# Kubernetes AI Toolchain Operator(KAITO)
+# Kubernetes AI Toolchain Operator (Kaito)
[![Go Report Card](https://goreportcard.com/badge/github.com/Azure/kaito)](https://goreportcard.com/report/github.com/Azure/kaito)
![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/Azure/kaito)
-KAITO has been designed to simplify the workflow of launching AI inference services against popular large open sourced AI models,
-such as Falcon or Llama, in a Kubernetes cluster.
+Kaito is an operator that automates AI/ML inference model deployment in a Kubernetes cluster.
+The target models are popular large open-source inference models such as [falcon](https://huggingface.co/tiiuae) and [llama 2](https://github.com/facebookresearch/llama).
+Compared to most mainstream model deployment methodologies built on virtual machine infrastructures, Kaito differs in the following ways:
+- Large model files are managed using container images. An HTTP server is provided to perform inference calls using the model library.
+- Preset configurations avoid the need to tune deployment parameters to fit the GPU hardware.
+- GPU nodes are auto-provisioned based on model requirements.
+- Large model images are hosted in the public Microsoft Container Registry (MCR) if the license allows.
+
+Using Kaito, the workflow of onboarding large AI inference models in Kubernetes is greatly simplified.
+
+
+## Architecture
+
+Kaito follows the classic Kubernetes Custom Resource Definition (CRD)/controller design pattern. The user manages a `workspace` custom resource that describes the GPU requirements and the inference specification. Kaito controllers automate the deployment by reconciling the `workspace` custom resource.
+
+
+
+
+The above figure presents the Kaito architecture overview. Its major components consist of:
+- **Workspace controller**: It reconciles the `workspace` custom resource, creates `machine` (explained below) custom resources to trigger node auto-provisioning, and creates the inference workload (`deployment` or `statefulset`) based on the model preset configurations.
+- **Node provisioner controller**: The controller's name is *gpu-provisioner* in the [Kaito helm chart](charts/kaito/gpu-provisioner). It uses the `machine` CRD originating from [Karpenter](https://github.com/aws/karpenter-core) to interact with the workspace controller. It integrates with Azure Kubernetes Service (AKS) APIs to add new GPU nodes to the AKS cluster.
+Note that the *gpu-provisioner* is not an open-sourced component. It can be replaced by other controllers that support the Karpenter-core APIs.
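+
+For illustration, a minimal `workspace` custom resource may look like the following sketch. The exact schema is defined by the `workspace` CRD; the field names and values below are assumptions based on the falcon-7b example used in the Quick start (`examples/kaito_workspace_falcon_7b.yaml`).
+
+```yaml
+apiVersion: kaito.sh/v1alpha1
+kind: Workspace
+metadata:
+  name: workspace-falcon-7b
+resource:
+  instanceType: "Standard_NC12s_v3"
+  labelSelector:
+    matchLabels:
+      apps: falcon-7b
+inference:
+  preset:
+    name: "falcon-7b"
+```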
+
+
+---
## Installation
-The following guidence assumes **Azure Kubernetes Service(AKS)** is used to host the Kubernetes cluster .
+The following guidance assumes **Azure Kubernetes Service (AKS)** is used to host the Kubernetes cluster.
-### Enable Workload Identity and OIDC Issuer features
-The `gpu-povisioner` component requires the [workload identity](https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview?tabs=dotnet) feature to acquire the token to access the AKS managed cluster with proper permissions.
+#### Enable Workload Identity and OIDC Issuer features
+The *gpu-provisioner* controller requires the [workload identity](https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview?tabs=dotnet) feature to acquire the access token to the AKS cluster.
```bash
export RESOURCE_GROUP="myResourceGroup"
@@ -18,8 +41,8 @@ export MY_CLUSTER="myCluster"
az aks update -g $RESOURCE_GROUP -n $MY_CLUSTER --enable-oidc-issuer --enable-workload-identity --enable-managed-identity
```
-### Create an identity and assign permissions
-The identity `kaitoprovisioner` is created for the `gpu-povisioner` controller. It is assigned Contributor role for the managed cluster resource to allow changing `$MY_CLUSTER` (e.g., provisioning new nodes in it).
+#### Create an identity and assign permissions
+The identity `kaitoprovisioner` is created for the *gpu-provisioner* controller. It is assigned the Contributor role for the managed cluster resource to allow changing `$MY_CLUSTER` (e.g., provisioning new nodes in it).
```bash
export SUBSCRIPTION="mySubscription"
az identity create --name kaitoprovisioner -g $RESOURCE_GROUP
@@ -29,7 +52,7 @@ az role assignment create --assignee $IDENTITY_PRINCIPAL_ID --scope /subscriptio
```
-### Install helm charts
+#### Install helm charts
Two charts will be installed in `$MY_CLUSTER`: `gpu-provisioner` chart and `workspace` chart.
```bash
helm install workspace ./charts/kaito/workspace
@@ -49,25 +72,27 @@ helm install gpu-provisioner ./charts/kaito/gpu-provisioner
```
-### Create federated credential
-This allows `gpu-provisioner` controller to use `kaitoprovisioner` identity via an access token.
+#### Create the federated credential
+Create the federated identity credential between the managed identity `kaitoprovisioner` and the service account used by the *gpu-provisioner* controller.
```bash
export AKS_OIDC_ISSUER=$(az aks show -n $MY_CLUSTER -g $RESOURCE_GROUP --subscription $SUBSCRIPTION --query "oidcIssuerProfile.issuerUrl" | tr -d '"')
az identity federated-credential create --name kaito-federatedcredential --identity-name kaitoprovisioner -g $RESOURCE_GROUP --issuer $AKS_OIDC_ISSUER --subject system:serviceaccount:"gpu-provisioner:gpu-provisioner" --audience api://AzureADTokenExchange --subscription $SUBSCRIPTION
```
-Note that before doing this step, the `gpu-provisioner` controller pod will constantly fail with the following message in the log:
+Then the *gpu-provisioner* can access the managed cluster using a trust token with the same permissions as the `kaitoprovisioner` identity.
+Note that until this step is done, the *gpu-provisioner* controller pod will repeatedly fail with the following message in the log:
```
panic: Configure azure client fails. Please ensure federatedcredential has been created for identity XXXX.
```
The pod will reach running state once the federated credential is created.
-### Clean up
+#### Clean up
```bash
helm uninstall gpu-provisioner
helm uninstall workspace
```
+---
## Quick start
After installing Kaito, one can try the following commands to start a falcon-7b inference service.
@@ -88,14 +113,14 @@ inference:
$ kubectl apply -f examples/kaito_workspace_falcon_7b.yaml
```
-The workspace status can be tracked by running the following command.
+The workspace status can be tracked by running the following command. When the WORKSPACEREADY column becomes `True`, the model has been deployed successfully.
```
$ kubectl get workspace workspace-falcon-7b
NAME INSTANCE RESOURCEREADY INFERENCEREADY WORKSPACEREADY AGE
workspace-falcon-7b Standard_NC12s_v3 True True True 10m
```
-Once the workspace is ready, one can find the inference service's cluster ip and use a temporal `curl` pod to test the service endpoint in cluster.
+Next, one can find the inference service's cluster IP and use a temporary `curl` pod to test the service endpoint in the cluster.
```
$ kubectl get svc workspace-falcon-7b
NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE
@@ -105,7 +130,7 @@ $ kubectl run -it --rm --restart=Never curl --image=curlimages/curl sh
~ $ curl -X POST http:///chat -H "accept: application/json" -H "Content-Type: application/json" -d "{\"prompt\":\"YOUR QUESTION HERE\"}"
```
-
+---
## Contributing
[Read more](docs/contributing/readme.md)
diff --git a/api/v1alpha1/sku_config.go b/api/v1alpha1/sku_config.go
new file mode 100644
index 000000000..68c9f0dc5
--- /dev/null
+++ b/api/v1alpha1/sku_config.go
@@ -0,0 +1,104 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+package v1alpha1
+
+import "strings"
+
+type GPUConfig struct {
+	SKU         string
+	SupportedOS []string
+	GPUDriver   string
+	GPUCount    int // number of GPUs per machine
+	GPUMem      int // total GPU memory per machine, in GB
+}
+
+type PresetRequirements struct {
+ MinGPUCount int
+ MinMemoryPerGPU int // in GB
+ MinTotalMemory int // in GB
+}
+
+var PresetRequirementsMap = map[string]PresetRequirements{
+ "falcon-7b": {MinGPUCount: 1, MinMemoryPerGPU: 0, MinTotalMemory: 15},
+ "falcon-7b-instruct": {MinGPUCount: 1, MinMemoryPerGPU: 0, MinTotalMemory: 15},
+ "falcon-40b": {MinGPUCount: 2, MinMemoryPerGPU: 0, MinTotalMemory: 90},
+ "falcon-40b-instruct": {MinGPUCount: 2, MinMemoryPerGPU: 0, MinTotalMemory: 90},
+
+ "llama-2-7b": {MinGPUCount: 1, MinMemoryPerGPU: 14, MinTotalMemory: 14},
+ "llama-2-13b": {MinGPUCount: 2, MinMemoryPerGPU: 15, MinTotalMemory: 30},
+ "llama-2-70b": {MinGPUCount: 8, MinMemoryPerGPU: 19, MinTotalMemory: 152},
+
+ "llama-2-7b-chat": {MinGPUCount: 1, MinMemoryPerGPU: 14, MinTotalMemory: 14},
+ "llama-2-13b-chat": {MinGPUCount: 2, MinMemoryPerGPU: 15, MinTotalMemory: 30},
+ "llama-2-70b-chat": {MinGPUCount: 8, MinMemoryPerGPU: 19, MinTotalMemory: 152},
+}
+
+// Helper function to check if a preset is valid
+func isValidPreset(preset string) bool {
+ _, exists := PresetRequirementsMap[preset]
+ return exists
+}
+
+func getSupportedSKUs() string {
+ skus := make([]string, 0, len(SupportedGPUConfigs))
+ for sku := range SupportedGPUConfigs {
+ skus = append(skus, sku)
+ }
+ return strings.Join(skus, ", ")
+}
+
+var SupportedGPUConfigs = map[string]GPUConfig{
+ "standard_nc6": {SKU: "standard_nc6", GPUCount: 1, GPUMem: 12, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"},
+ "standard_nc12": {SKU: "standard_nc12", GPUCount: 2, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"},
+ "standard_nc24": {SKU: "standard_nc24", GPUCount: 4, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"},
+ "standard_nc24r": {SKU: "standard_nc24r", GPUCount: 4, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"},
+ "standard_nv6": {SKU: "standard_nv6", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv12": {SKU: "standard_nv12", GPUCount: 2, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv24": {SKU: "standard_nv24", GPUCount: 4, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv12s_v3": {SKU: "standard_nv12s_v3", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv24s_v3": {SKU: "standard_nv24s_v3", GPUCount: 2, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv48s_v3": {SKU: "standard_nv48s_v3", GPUCount: 4, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ // "standard_nv24r": {SKU: "standard_nv24r", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nd6s": {SKU: "standard_nd6s", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nd12s": {SKU: "standard_nd12s", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nd24s": {SKU: "standard_nd24s", GPUCount: 4, GPUMem: 96, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nd24rs": {SKU: "standard_nd24rs", GPUCount: 4, GPUMem: 96, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc6s_v2": {SKU: "standard_nc6s_v2", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc12s_v2": {SKU: "standard_nc12s_v2", GPUCount: 2, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc24s_v2": {SKU: "standard_nc24s_v2", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc24rs_v2": {SKU: "standard_nc24rs_v2", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc6s_v3": {SKU: "standard_nc6s_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc12s_v3": {SKU: "standard_nc12s_v3", GPUCount: 2, GPUMem: 32, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc24s_v3": {SKU: "standard_nc24s_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc24rs_v3": {SKU: "standard_nc24rs_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_nd40s_v3": {SKU: "standard_nd40s_v3", GPUCount: x, GPUMem: x, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nd40rs_v2": {SKU: "standard_nd40rs_v2", GPUCount: 8, GPUMem: 256, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc4as_t4_v3": {SKU: "standard_nc4as_t4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc8as_t4_v3": {SKU: "standard_nc8as_t4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc16as_t4_v3": {SKU: "standard_nc16as_t4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc64as_t4_v3": {SKU: "standard_nc64as_t4_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nd96asr_v4": {SKU: "standard_nd96asr_v4", GPUCount: 8, GPUMem: 320, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_nd112asr_a100_v4": {SKU: "standard_nd112asr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_nd120asr_a100_v4": {SKU: "standard_nd120asr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nd96amsr_a100_v4": {SKU: "standard_nd96amsr_a100_v4", GPUCount: 8, GPUMem: 640, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_nd112amsr_a100_v4": {SKU: "standard_nd112amsr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_nd120amsr_a100_v4": {SKU: "standard_nd120amsr_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc24ads_a100_v4": {SKU: "standard_nc24ads_a100_v4", GPUCount: 1, GPUMem: 80, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc48ads_a100_v4": {SKU: "standard_nc48ads_a100_v4", GPUCount: 2, GPUMem: 160, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ "standard_nc96ads_a100_v4": {SKU: "standard_nc96ads_a100_v4", GPUCount: 4, GPUMem: 320, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_ncads_a100_v4": {SKU: "standard_ncads_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ /*GPU Mem based on A10-24 Spec - TODO: Need to confirm GPU Mem*/
+ // "standard_nc8ads_a10_v4": {SKU: "standard_nc8ads_a10_v4", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ // "standard_nc16ads_a10_v4": {SKU: "standard_nc16ads_a10_v4", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ // "standard_nc32ads_a10_v4": {SKU: "standard_nc32ads_a10_v4", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ /* SKUs with GPU Partition are treated as 1 GPU - https://learn.microsoft.com/en-us/azure/virtual-machines/nva10v5-series*/
+ "standard_nv6ads_a10_v5": {SKU: "standard_nv6ads_a10_v5", GPUCount: 1, GPUMem: 4, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv12ads_a10_v5": {SKU: "standard_nv12ads_a10_v5", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv18ads_a10_v5": {SKU: "standard_nv18ads_a10_v5", GPUCount: 1, GPUMem: 12, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv36ads_a10_v5": {SKU: "standard_nv36ads_a10_v5", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv36adms_a10_v5": {SKU: "standard_nv36adms_a10_v5", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ "standard_nv72ads_a10_v5": {SKU: "standard_nv72ads_a10_v5", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"},
+ // "standard_nd96ams_v4": {SKU: "standard_nd96ams_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+ // "standard_nd96ams_a100_v4": {SKU: "standard_nd96ams_a100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"},
+}
diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go
index 61e091b72..4496d3311 100644
--- a/api/v1alpha1/workspace_validation.go
+++ b/api/v1alpha1/workspace_validation.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"reflect"
+ "strings"
admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -11,6 +12,11 @@ import (
"knative.dev/pkg/apis"
)
+const (
+ N_SERIES_PREFIX = "standard_n"
+ D_SERIES_PREFIX = "standard_d"
+)
+
func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
return []admissionregistrationv1.OperationType{
admissionregistrationv1.Create,
@@ -32,6 +38,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
- The preset name needs to be supported enum.
*/
errs = errs.Also(
+ w.Resource.validateCreate(w.Inference).ViaField("resource"),
w.Inference.validateCreate().ViaField("inference"),
)
} else {
@@ -45,6 +52,46 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
return errs
}
+func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) {
+	var presetName string
+	if inference.Preset != nil { // guard against a nil preset (e.g., when a custom template is used)
+		presetName = strings.ToLower(string(inference.Preset.Name))
+	}
+	instanceType := strings.ToLower(string(r.InstanceType))
+
+ // Check if instancetype exists in our SKUs map
+ if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists {
+ // Validate GPU count for given SKU
+ if presetReq, ok := PresetRequirementsMap[presetName]; ok {
+			machineCount := *r.Count
+			totalNumGPUs := machineCount * skuConfig.GPUCount
+			// GPUMem is the total GPU memory of one machine, so do not multiply by GPUCount again.
+			totalGPUMem := machineCount * skuConfig.GPUMem
+			perGPUMem := skuConfig.GPUMem / skuConfig.GPUCount
+
+			// Separate the checks for specific error messages
+			if totalNumGPUs < presetReq.MinGPUCount {
+				errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient number of GPUs: Instance type %s provides %d, but preset %s requires at least %d", instanceType, totalNumGPUs, presetName, presetReq.MinGPUCount), "instanceType"))
+			}
+			if perGPUMem < presetReq.MinMemoryPerGPU {
+				errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient GPU memory: Instance type %s provides %d GB per GPU, but preset %s requires at least %d GB per GPU", instanceType, perGPUMem, presetName, presetReq.MinMemoryPerGPU), "instanceType"))
+			}
+			if totalGPUMem < presetReq.MinTotalMemory {
+				errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient total GPU memory: Instance type %s has a total of %d GB, but preset %s requires at least %d GB", instanceType, totalGPUMem, presetName, presetReq.MinTotalMemory), "instanceType"))
+			}
+ } else {
+ errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName"))
+ }
+ } else {
+		// Fall back to prefix matching for other N-series and D-series instance types
+ if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) {
+ errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType"))
+ }
+ }
+
+ // Validate labelSelector
+ if _, err := metav1.LabelSelectorAsMap(r.LabelSelector); err != nil {
+ errs = errs.Also(apis.ErrInvalidValue(err.Error(), "labelSelector"))
+ }
+
+ return errs
+}
+
func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) {
// We disable changing node count for now.
if r.Count != nil && old.Count != nil && *r.Count != *old.Count {
@@ -66,6 +113,11 @@ func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError)
}
func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
+	// Validate the preset name (guard against a nil preset)
+	if i.Preset != nil {
+		presetName := strings.ToLower(string(i.Preset.Name))
+		if !isValidPreset(presetName) {
+			errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName"))
+		}
+	}
if i.Preset != nil && i.Template != nil {
errs = errs.Also(apis.ErrGeneric("preset and template cannot be set at the same time"))
}
diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go
new file mode 100644
index 000000000..39eda44a7
--- /dev/null
+++ b/api/v1alpha1/workspace_validation_test.go
@@ -0,0 +1,478 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+package v1alpha1
+
+import (
+ "reflect"
+ "sort"
+ "strings"
+ "testing"
+
+ v1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+func pointerToInt(i int) *int {
+ return &i
+}
+
+func TestResourceSpecValidateCreate(t *testing.T) {
+ tests := []struct {
+ name string
+ resourceSpec *ResourceSpec
+ inferenceSpec *InferenceSpec
+ errContent string // Content expect error to include, if any
+ expectErrs bool
+ }{
+ {
+ name: "Insufficient total GPU memory",
+ resourceSpec: &ResourceSpec{
+ InstanceType: "standard_nc6",
+ Count: pointerToInt(1),
+ },
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("falcon-7b"),
+ },
+ },
+ },
+ errContent: "Insufficient total GPU memory",
+ expectErrs: true,
+ },
+
+ {
+ name: "Insufficient number of GPUs",
+ resourceSpec: &ResourceSpec{
+ InstanceType: "standard_nc24ads_a100_v4",
+ Count: pointerToInt(1),
+ },
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("llama-2-13b-chat"),
+ },
+ },
+ },
+ errContent: "Insufficient number of GPUs",
+ expectErrs: true,
+ },
+
+ {
+ name: "Invalid SKU",
+ resourceSpec: &ResourceSpec{
+ InstanceType: "standard_invalid_sku",
+ Count: pointerToInt(1),
+ },
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("llama-2-70b"),
+ },
+ },
+ },
+ errContent: "Unsupported instance",
+ expectErrs: true,
+ },
+
+ {
+ name: "Invalid Preset",
+ resourceSpec: &ResourceSpec{
+ InstanceType: "standard_nv12s_v3",
+ Count: pointerToInt(1),
+ },
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("invalid-preset"),
+ },
+ },
+ },
+ errContent: "Unsupported preset",
+ expectErrs: true,
+ },
+
+ {
+ name: "N-Prefix SKU",
+ resourceSpec: &ResourceSpec{
+ InstanceType: "standard_nsku",
+ Count: pointerToInt(1),
+ },
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("llama-2-7b"),
+ },
+ },
+ },
+ errContent: "",
+ expectErrs: false,
+ },
+
+ {
+ name: "D-Prefix SKU",
+ resourceSpec: &ResourceSpec{
+ InstanceType: "standard_dsku",
+ Count: pointerToInt(1),
+ },
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("llama-2-7b"),
+ },
+ },
+ },
+ errContent: "",
+ expectErrs: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ errs := tc.resourceSpec.validateCreate(*tc.inferenceSpec)
+ hasErrs := errs != nil
+ if hasErrs != tc.expectErrs {
+ t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
+ }
+
+ // If there is an error and errContent is not empty, check that the error contains the expected content.
+ if hasErrs && tc.errContent != "" {
+ errMsg := errs.Error()
+ if !strings.Contains(errMsg, tc.errContent) {
+ t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
+ }
+ }
+ })
+ }
+}
+
+func TestResourceSpecValidateUpdate(t *testing.T) {
+
+ tests := []struct {
+ name string
+ newResource *ResourceSpec
+ oldResource *ResourceSpec
+ errContent string // Content expected error to include, if any
+ expectErrs bool
+ }{
+ {
+ name: "Immutable Count",
+ newResource: &ResourceSpec{
+ Count: pointerToInt(10),
+ },
+ oldResource: &ResourceSpec{
+ Count: pointerToInt(5),
+ },
+ errContent: "field is immutable",
+ expectErrs: true,
+ },
+ {
+ name: "Immutable InstanceType",
+ newResource: &ResourceSpec{
+ InstanceType: "new_type",
+ },
+ oldResource: &ResourceSpec{
+ InstanceType: "old_type",
+ },
+ errContent: "field is immutable",
+ expectErrs: true,
+ },
+ {
+ name: "Immutable LabelSelector",
+ newResource: &ResourceSpec{
+ LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key1": "value1"}},
+ },
+ oldResource: &ResourceSpec{
+ LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key2": "value2"}},
+ },
+ errContent: "field is immutable",
+ expectErrs: true,
+ },
+ {
+ name: "Valid Update",
+ newResource: &ResourceSpec{
+ Count: pointerToInt(5),
+ InstanceType: "same_type",
+ LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key": "value"}},
+ },
+ oldResource: &ResourceSpec{
+ Count: pointerToInt(5),
+ InstanceType: "same_type",
+ LabelSelector: &metav1.LabelSelector{MatchLabels: map[string]string{"key": "value"}},
+ },
+ errContent: "",
+ expectErrs: false,
+ },
+ }
+
+ // Run the tests
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ errs := tc.newResource.validateUpdate(tc.oldResource)
+ hasErrs := errs != nil
+ if hasErrs != tc.expectErrs {
+ t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs)
+ }
+
+ // If there is an error and errContent is not empty, check that the error contains the expected content.
+ if hasErrs && tc.errContent != "" {
+ errMsg := errs.Error()
+ if !strings.Contains(errMsg, tc.errContent) {
+ t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
+ }
+ }
+ })
+ }
+}
+
+func TestInferenceSpecValidateCreate(t *testing.T) {
+ tests := []struct {
+ name string
+ inferenceSpec *InferenceSpec
+ errContent string // Content expected error to include, if any
+ expectErrs bool
+ }{
+ {
+ name: "Invalid Preset Name",
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("Invalid-Preset-Name"),
+ },
+ },
+ },
+ errContent: "Unsupported preset name",
+ expectErrs: true,
+ },
+ {
+ name: "Preset and Template Set",
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("falcon-7b"),
+ },
+ },
+ Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
+ },
+ errContent: "preset and template cannot be set at the same time",
+ expectErrs: true,
+ },
+ {
+ name: "Private Access Without Image",
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("llama-2-7b"),
+ AccessMode: "private",
+ },
+ PresetOptions: PresetOptions{},
+ },
+ },
+ errContent: "When AccessMode is private, an image must be provided",
+ expectErrs: true,
+ },
+ {
+ name: "Valid Preset",
+ inferenceSpec: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("falcon-7b"),
+ AccessMode: "public",
+ },
+ },
+ },
+ errContent: "",
+ expectErrs: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ errs := tc.inferenceSpec.validateCreate()
+ hasErrs := errs != nil
+ if hasErrs != tc.expectErrs {
+ t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
+ }
+
+ // If there is an error and errContent is not empty, check that the error contains the expected content.
+ if hasErrs && tc.errContent != "" {
+ errMsg := errs.Error()
+ if !strings.Contains(errMsg, tc.errContent) {
+ t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
+ }
+ }
+ })
+ }
+}
+
+func TestInferenceSpecValidateUpdate(t *testing.T) {
+ tests := []struct {
+ name string
+ newInference *InferenceSpec
+ oldInference *InferenceSpec
+ errContent string // Content expected error to include, if any
+ expectErrs bool
+ }{
+ {
+ name: "Preset Immutable",
+ newInference: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("new-preset"),
+ },
+ },
+ },
+ oldInference: &InferenceSpec{
+ Preset: &PresetSpec{
+ PresetMeta: PresetMeta{
+ Name: ModelName("old-preset"),
+ },
+ },
+ },
+ errContent: "field is immutable",
+ expectErrs: true,
+ },
+ {
+ name: "Template Unset",
+ newInference: &InferenceSpec{
+ Template: nil,
+ },
+ oldInference: &InferenceSpec{
+ Template: &v1.PodTemplateSpec{},
+ },
+ errContent: "field cannot be unset/set if it was set/unset",
+ expectErrs: true,
+ },
+ {
+ name: "Template Set",
+ newInference: &InferenceSpec{
+ Template: &v1.PodTemplateSpec{},
+ },
+ oldInference: &InferenceSpec{
+ Template: nil,
+ },
+ errContent: "field cannot be unset/set if it was set/unset",
+ expectErrs: true,
+ },
+ {
+ name: "Valid Update",
+ newInference: &InferenceSpec{
+ Template: &v1.PodTemplateSpec{},
+ },
+ oldInference: &InferenceSpec{
+ Template: &v1.PodTemplateSpec{},
+ },
+ errContent: "",
+ expectErrs: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ errs := tc.newInference.validateUpdate(tc.oldInference)
+ hasErrs := errs != nil
+ if hasErrs != tc.expectErrs {
+ t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs)
+ }
+
+ // If there is an error and errContent is not empty, check that the error contains the expected content.
+ if hasErrs && tc.errContent != "" {
+ errMsg := errs.Error()
+ if !strings.Contains(errMsg, tc.errContent) {
+ t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
+ }
+ }
+ })
+ }
+}
+
+func TestGetSupportedSKUs(t *testing.T) {
+ tests := []struct {
+ name string
+ gpuConfigs map[string]GPUConfig
+ expectedResult []string // changed to a slice for deterministic ordering
+ }{
+ {
+ name: "no SKUs supported",
+ gpuConfigs: map[string]GPUConfig{},
+ expectedResult: []string{""},
+ },
+ {
+ name: "one SKU supported",
+ gpuConfigs: map[string]GPUConfig{
+ "standard_nc6": {SKU: "standard_nc6"},
+ },
+ expectedResult: []string{"standard_nc6"},
+ },
+ {
+ name: "multiple SKUs supported",
+ gpuConfigs: map[string]GPUConfig{
+ "standard_nc6": {SKU: "standard_nc6"},
+ "standard_nc12": {SKU: "standard_nc12"},
+ },
+ expectedResult: []string{"standard_nc6", "standard_nc12"},
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+			originalConfigs := SupportedGPUConfigs
+			SupportedGPUConfigs = tc.gpuConfigs
+			defer func() { SupportedGPUConfigs = originalConfigs }()
+
+ resultSlice := strings.Split(getSupportedSKUs(), ", ")
+ sort.Strings(resultSlice)
+
+ // Sort the expectedResult for comparison
+ expectedResultSlice := tc.expectedResult
+ sort.Strings(expectedResultSlice)
+
+ if !reflect.DeepEqual(resultSlice, expectedResultSlice) {
+ t.Errorf("getSupportedSKUs() = %v, want %v", resultSlice, expectedResultSlice)
+ }
+ })
+ }
+}
+
+func TestIsValidPreset(t *testing.T) {
+ tests := []struct {
+ name string
+ preset string
+ expectValid bool
+ }{
+ {
+ name: "valid preset",
+ preset: "falcon-7b",
+ expectValid: true,
+ },
+ {
+ name: "invalid preset",
+ preset: "nonexistent-preset",
+ expectValid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if valid := isValidPreset(test.preset); valid != test.expectValid {
+ t.Errorf("isValidPreset(%s) = %v, want %v", test.preset, valid, test.expectValid)
+ }
+ })
+ }
+}
diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go
index bfdc825b4..ef60d7362 100644
--- a/api/v1alpha1/zz_generated.deepcopy.go
+++ b/api/v1alpha1/zz_generated.deepcopy.go
@@ -14,6 +14,26 @@ import (
runtime "k8s.io/apimachinery/pkg/runtime"
)
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *GPUConfig) DeepCopyInto(out *GPUConfig) {
+ *out = *in
+ if in.SupportedOS != nil {
+ in, out := &in.SupportedOS, &out.SupportedOS
+ *out = make([]string, len(*in))
+ copy(*out, *in)
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUConfig.
+func (in *GPUConfig) DeepCopy() *GPUConfig {
+ if in == nil {
+ return nil
+ }
+ out := new(GPUConfig)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *InferenceSpec) DeepCopyInto(out *InferenceSpec) {
*out = *in
@@ -74,6 +94,21 @@ func (in *PresetOptions) DeepCopy() *PresetOptions {
return out
}
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *PresetRequirements) DeepCopyInto(out *PresetRequirements) {
+ *out = *in
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PresetRequirements.
+func (in *PresetRequirements) DeepCopy() *PresetRequirements {
+ if in == nil {
+ return nil
+ }
+ out := new(PresetRequirements)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *PresetSpec) DeepCopyInto(out *PresetSpec) {
*out = *in
diff --git a/docs/img/arch.png b/docs/img/arch.png
new file mode 100644
index 000000000..94c399faf
Binary files /dev/null and b/docs/img/arch.png differ