feat: support vllm in controller
- set vllm as the default runtime

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh committed Oct 17, 2024
1 parent 6481b76 commit 936993c
Showing 24 changed files with 534 additions and 234 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -307,11 +307,11 @@ azure-karpenter-helm: ## Update Azure client env vars and settings in helm valu
##@ Build
.PHONY: build
build: manifests generate fmt vet ## Build manager binary.
go build -o bin/manager cmd/*.go
go build -o bin/manager cmd/workspace/*.go

.PHONY: run
run: manifests generate fmt vet ## Run a controller from your host.
go run ./cmd/main.go
go run ./cmd/workspace/main.go

##@ Build Dependencies
## Location to install dependencies to
23 changes: 23 additions & 0 deletions api/v1alpha1/labels.go
@@ -3,6 +3,8 @@

package v1alpha1

import "github.com/azure/kaito/pkg/model"

const (

// Non-prefixed labels/annotations are reserved for end-use.
@@ -27,4 +29,25 @@ const (

// WorkspaceRevisionAnnotation is the Annotations for revision number
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"

// AnnotationWorkspaceBackend is the annotation for backend selection.
AnnotationWorkspaceBackend = KAITOPrefix + "backend"
)

// GetWorkspaceBackendName returns the runtime name of the workspace.
func GetWorkspaceBackendName(ws *Workspace) model.BackendName {
if ws == nil {
panic("workspace is nil")
}
runtime := model.BackendNameVLLM

name := ws.Annotations[AnnotationWorkspaceBackend]
switch name {
case string(model.BackendNameHuggingfaceTransformers):
runtime = model.BackendNameHuggingfaceTransformers
case string(model.BackendNameVLLM):
runtime = model.BackendNameVLLM
}

return runtime
}
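
The helper above makes vLLM the fallback whenever the backend annotation is absent or carries an unrecognized value. Below is a minimal sketch of how calling code might exercise it; the identifiers come straight from this diff, but the concrete strings behind KAITOPrefix and the model.BackendName* constants are not shown in this commit, so only the constants themselves are referenced:

package main

import (
	"fmt"

	"github.com/azure/kaito/api/v1alpha1"
	"github.com/azure/kaito/pkg/model"
)

func main() {
	// No annotation set: the helper falls back to the new default, vLLM.
	fmt.Println(v1alpha1.GetWorkspaceBackendName(&v1alpha1.Workspace{}))

	// Annotating the workspace opts back into the Huggingface transformers runtime.
	ws := &v1alpha1.Workspace{}
	ws.Annotations = map[string]string{
		v1alpha1.AnnotationWorkspaceBackend: string(model.BackendNameHuggingfaceTransformers),
	}
	fmt.Println(v1alpha1.GetWorkspaceBackendName(ws))
}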
4 changes: 2 additions & 2 deletions api/v1alpha1/workspace_validation.go
@@ -169,7 +169,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
@@ -407,7 +407,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !utils.IsValidPreset(presetName) {
if !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
5 changes: 5 additions & 0 deletions api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default.

3 changes: 1 addition & 2 deletions charts/kaito/ragengine/crds/kaito.sh_ragengines.yaml
@@ -104,8 +104,7 @@ spec:
preferredNodes:
description: |-
PreferredNodes is an optional node list specified by the user.
If a node in the list does not have the required labels or
the required instanceType, it will be ignored.
If a node in the list does not have the required labels, it will be ignored.
items:
type: string
type: array
3 changes: 1 addition & 2 deletions charts/kaito/workspace/crds/kaito.sh_workspaces.yaml
@@ -212,8 +212,7 @@ spec:
preferredNodes:
description: |-
PreferredNodes is an optional node list specified by the user.
If a node in the list does not have the required labels or
the required instanceType, it will be ignored.
If a node in the list does not have the required labels, it will be ignored.
items:
type: string
type: array
3 changes: 1 addition & 2 deletions config/crd/bases/kaito.sh_ragengines.yaml
@@ -104,8 +104,7 @@ spec:
preferredNodes:
description: |-
PreferredNodes is an optional node list specified by the user.
If a node in the list does not have the required labels or
the required instanceType, it will be ignored.
If a node in the list does not have the required labels, it will be ignored.
items:
type: string
type: array
3 changes: 1 addition & 2 deletions config/crd/bases/kaito.sh_workspaces.yaml
@@ -212,8 +212,7 @@ spec:
preferredNodes:
description: |-
PreferredNodes is an optional node list specified by the user.
If a node in the list does not have the required labels or
the required instanceType, it will be ignored.
If a node in the list does not have the required labels, it will be ignored.
items:
type: string
type: array
2 changes: 1 addition & 1 deletion pkg/controllers/workspace_controller.go
@@ -851,7 +851,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, inferenceParam, model.SupportDistributedInference(), c.Client)
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, model, c.Client)
if err != nil {
return
}
101 changes: 46 additions & 55 deletions pkg/inference/preset-inferences.go
@@ -6,11 +6,11 @@ import (
"context"
"fmt"
"os"
"strconv"

"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/consts"

"github.com/azure/kaito/api/v1alpha1"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/resources"
@@ -22,9 +22,8 @@
)

const (
ProbePath = "/healthz"
Port5000 = int32(5000)
InferenceFile = "inference_api.py"
ProbePath = "/healthz"
Port5000 = int32(5000)
)

var (
@@ -71,27 +70,27 @@
)

func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error {
existingService := &corev1.Service{}
err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService)
if err != nil {
return err
}

nodes := *wObj.Resource.Count
inferenceObj.TorchRunParams["nnodes"] = strconv.Itoa(nodes)
inferenceObj.TorchRunParams["nproc_per_node"] = strconv.Itoa(inferenceObj.WorldSize / nodes)
if nodes > 1 {
inferenceObj.TorchRunParams["node_rank"] = "$(echo $HOSTNAME | grep -o '[^-]*$')"
inferenceObj.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP
inferenceObj.TorchRunParams["master_port"] = "29500"
}
if inferenceObj.TorchRunRdzvParams != nil {
inferenceObj.TorchRunRdzvParams["max_restarts"] = "3"
inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job"
inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d"
inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] =
fmt.Sprintf("%s-0.%s-headless.%s.svc.cluster.local:29500", wObj.Name, wObj.Name, wObj.Namespace)
}
// existingService := &corev1.Service{}
// err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService)
// if err != nil {
// return err
// }

// nodes := *wObj.Resource.Count
// inferenceObj.TorchRunParams["nnodes"] = strconv.Itoa(nodes)
// inferenceObj.TorchRunParams["nproc_per_node"] = strconv.Itoa(inferenceObj.WorldSize / nodes)
// if nodes > 1 {
// inferenceObj.TorchRunParams["node_rank"] = "$(echo $HOSTNAME | grep -o '[^-]*$')"
// inferenceObj.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP
// inferenceObj.TorchRunParams["master_port"] = "29500"
// }
// if inferenceObj.TorchRunRdzvParams != nil {
// inferenceObj.TorchRunRdzvParams["max_restarts"] = "3"
// inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job"
// inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d"
// inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] =
// fmt.Sprintf("%s-0.%s-headless.%s.svc.cluster.local:29500", wObj.Name, wObj.Name, wObj.Namespace)
// }
return nil
}

@@ -114,14 +113,17 @@ func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Work
}

func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, revisionNum string,
inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
if inferenceObj.TorchRunParams != nil && supportDistributedInference {
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceObj); err != nil {
model model.Model, kubeClient client.Client) (client.Object, error) {
inferenceParam := model.GetInferenceParameters().DeepCopy()

if model.SupportDistributedInference() {
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceParam); err != nil { //
klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj)
return nil, err
}
}

// additional volume
var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*workspaceObj.Resource.Count)
@@ -131,24 +133,35 @@
if shmVolumeMount.Name != "" {
volumeMounts = append(volumeMounts, shmVolumeMount)
}

if len(workspaceObj.Inference.Adapters) > 0 {
adapterVolume, adapterVolumeMount := utils.ConfigAdapterVolume()
volumes = append(volumes, adapterVolume)
volumeMounts = append(volumeMounts, adapterVolumeMount)
}

// resource requirements
skuNumGPUs, err := utils.GetSKUNumGPUs(ctx, kubeClient, workspaceObj.Status.WorkerNodes,
workspaceObj.Resource.InstanceType, inferenceObj.GPUCountRequirement)
workspaceObj.Resource.InstanceType, inferenceParam.GPUCountRequirement)
if err != nil {
return nil, fmt.Errorf("failed to get SKU num GPUs: %v", err)
}
resourceReq := corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
Limits: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
}

// inference command
backendName := v1alpha1.GetWorkspaceBackendName(workspaceObj)
commands := inferenceParam.GetInferenceCommand(backendName)

commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj, skuNumGPUs)
image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj)
image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceParam)

var depObj client.Object
if supportDistributedInference {
if model.SupportDistributedInference() {
depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
} else {
@@ -161,25 +174,3 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
}
return depObj, nil
}

// prepareInferenceParameters builds a PyTorch command:
// torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
// and sets the GPU resources required for inference.
// Returns the command and resource configuration.
func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam, skuNumGPUs string) ([]string, corev1.ResourceRequirements) {
torchCommand := utils.BuildCmdStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
torchCommand = utils.BuildCmdStr(torchCommand, inferenceObj.TorchRunRdzvParams)
modelCommand := utils.BuildCmdStr(InferenceFile, inferenceObj.ModelRunParams)
commands := utils.ShellCmd(torchCommand + " " + modelCommand)

resourceRequirements := corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
Limits: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
}

return commands, resourceRequirements
}
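
With prepareInferenceParameters removed, this package no longer assembles a torchrun command itself: CreatePresetInference now asks the preset for a backend-specific command via GetInferenceCommand(backendName). That method lives in pkg/model and is not part of this diff; the toy sketch below is purely an illustration of the shape of that dispatch, with stand-in types and assumed string values (the vLLM branch mirrors the command the updated tests expect):

package main

import "fmt"

// Toy stand-ins: the real BackendName and PresetParam types live in pkg/model,
// and the actual GetInferenceCommand implementation is not included in this commit.
type BackendName string

const (
	BackendNameHuggingfaceTransformers BackendName = "transformers" // string value assumed
	BackendNameVLLM                    BackendName = "vllm"         // string value assumed
)

type PresetParam struct{}

// GetInferenceCommand illustrates the per-backend dispatch that CreatePresetInference
// now relies on in place of the removed prepareInferenceParameters helper. The vLLM
// branch matches the command the updated tests expect ("/bin/sh -c inference_api_vllm.py").
func (p *PresetParam) GetInferenceCommand(backend BackendName) []string {
	switch backend {
	case BackendNameHuggingfaceTransformers:
		return []string{"/bin/sh", "-c", "inference_api.py"}
	default: // vLLM is the new default runtime
		return []string{"/bin/sh", "-c", "inference_api_vllm.py"}
	}
}

func main() {
	var p PresetParam
	fmt.Println(p.GetInferenceCommand(BackendNameVLLM))
}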
16 changes: 4 additions & 12 deletions pkg/inference/preset-inferences_test.go
@@ -15,7 +15,6 @@ import (
"github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/utils/test"

"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/utils/plugin"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
@@ -46,7 +45,7 @@ func TestCreatePresetInference(t *testing.T) {
workload: "Deployment",
// No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams
// So expected cmd consists of shell command and inference file
expectedCmd: "/bin/sh -c inference_api.py",
expectedCmd: "/bin/sh -c inference_api_vllm.py",
hasAdapters: false,
},

@@ -58,7 +57,7 @@
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
},
workload: "StatefulSet",
expectedCmd: "/bin/sh -c inference_api.py",
expectedCmd: "/bin/sh -c inference_api_vllm.py",
hasAdapters: false,
},

@@ -69,7 +68,7 @@
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workload: "Deployment",
expectedCmd: "/bin/sh -c inference_api.py",
expectedCmd: "/bin/sh -c inference_api_vllm.py",
hasAdapters: true,
expectedVolume: "adapter-volume",
},
@@ -96,15 +95,8 @@
}
}

useHeadlessSvc := false

var inferenceObj *model.PresetParam
model := plugin.KaitoModelRegister.MustGet(tc.modelName)
inferenceObj = model.GetInferenceParameters()

if strings.Contains(tc.modelName, "distributed") {
useHeadlessSvc = true
}
svc := &corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: workspace.Name,
@@ -116,7 +108,7 @@
}
mockClient.CreateOrUpdateObjectInMap(svc)

createdObject, _ := CreatePresetInference(context.TODO(), workspace, test.MockWorkspaceWithPresetHash, inferenceObj, useHeadlessSvc, mockClient)
createdObject, _ := CreatePresetInference(context.TODO(), workspace, test.MockWorkspaceWithPresetHash, model, mockClient)
createdWorkload := ""
switch createdObject.(type) {
case *appsv1.Deployment: