Skip to content

Commit

Permalink
feat: support vllm in controller
Browse files Browse the repository at this point in the history
- set vllm as the default runtime via a feature flag

Signed-off-by: jerryzhuang <[email protected]>
  • Loading branch information
zhuangqh committed Dec 2, 2024
1 parent 21056a1 commit 933107e
Show file tree
Hide file tree
Showing 26 changed files with 892 additions and 258 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ GINKGO_FOCUS ?=
GINKGO_SKIP ?=
GINKGO_NODES ?= 2
GINKGO_NO_COLOR ?= false
GINKGO_TIMEOUT ?= 180m
GINKGO_ARGS ?= -focus="$(GINKGO_FOCUS)" -skip="$(GINKGO_SKIP)" -nodes=$(GINKGO_NODES) -no-color=$(GINKGO_NO_COLOR) -timeout=$(GINKGO_TIMEOUT) --fail-fast
GINKGO_TIMEOUT ?= 120m
GINKGO_ARGS ?= -focus="$(GINKGO_FOCUS)" -skip="$(GINKGO_SKIP)" -nodes=$(GINKGO_NODES) -no-color=$(GINKGO_NO_COLOR) --output-interceptor-mode=none -timeout=$(GINKGO_TIMEOUT)

$(E2E_TEST):
(cd test/e2e && go test -c . -o $(E2E_TEST))
Expand Down
30 changes: 30 additions & 0 deletions api/v1alpha1/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

package v1alpha1

import (
"github.com/kaito-project/kaito/pkg/featuregates"
"github.com/kaito-project/kaito/pkg/model"
"github.com/kaito-project/kaito/pkg/utils/consts"
)

const (

// Non-prefixed labels/annotations are reserved for end-use.
Expand Down Expand Up @@ -30,4 +36,28 @@ const (

// RAGEngineRevisionAnnotation is the Annotations for revision number
RAGEngineRevisionAnnotation = "ragengine.kaito.io/revision"

// AnnotationWorkspaceRuntime is the annotation for runtime selection.
AnnotationWorkspaceRuntime = KAITOPrefix + "runtime"
)

// GetWorkspaceRuntimeName returns the runtime name of the workspace.
// An explicit runtime annotation on the workspace wins; otherwise the
// default is vLLM when the vLLM feature gate is enabled, and Huggingface
// transformers when it is not.
func GetWorkspaceRuntimeName(ws *Workspace) model.RuntimeName {
	if ws == nil {
		panic("workspace is nil")
	}

	// Honor a valid, explicitly annotated runtime first.
	switch ws.Annotations[AnnotationWorkspaceRuntime] {
	case string(model.RuntimeNameVLLM):
		return model.RuntimeNameVLLM
	case string(model.RuntimeNameHuggingfaceTransformers):
		return model.RuntimeNameHuggingfaceTransformers
	}

	// No (or unrecognized) annotation: fall back to the feature-gated default.
	if featuregates.FeatureGates[consts.FeatureFlagVLLM] {
		return model.RuntimeNameVLLM
	}
	return model.RuntimeNameHuggingfaceTransformers
}
4 changes: 2 additions & 2 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
Expand Down Expand Up @@ -404,7 +404,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !utils.IsValidPreset(presetName) {
if !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
Expand Down
1 change: 1 addition & 0 deletions pkg/featuregates/featuregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
// FeatureGates is a map that holds the feature gates and their default values for Kaito.
FeatureGates = map[string]bool{
consts.FeatureFlagKarpenter: false,
consts.FeatureFlagVLLM: false,
// Add more feature gates here
}
)
Expand Down
145 changes: 132 additions & 13 deletions pkg/model/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package model

import (
"time"

"github.com/kaito-project/kaito/pkg/utils"
)

type Model interface {
Expand All @@ -13,23 +15,140 @@ type Model interface {
SupportTuning() bool
}

// RuntimeName identifies the LLM serving runtime used to run a preset model.
type RuntimeName string

const (
	// RuntimeNameHuggingfaceTransformers serves models with the Huggingface transformers stack.
	RuntimeNameHuggingfaceTransformers RuntimeName = "transformers"
	// RuntimeNameVLLM serves models with the vLLM engine.
	RuntimeNameVLLM RuntimeName = "vllm"
)

// PresetParam defines the preset inference parameters for a model.
type PresetParam struct {
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.
DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the Preset. Used for inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. Used for inference.
PerGPUMemoryRequirement string // GPU memory required per GPU. Used for inference.
TuningPerGPUMemoryRequirement map[string]int // Min GPU memory per tuning method (batch size 1). Used for tuning.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
BaseCommand string // The initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
ModelRunParams map[string]string // Parameters for running the model training/inference.
Tag string // The model image tag
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.

DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the Preset. Used for inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. Used for inference.
PerGPUMemoryRequirement string // GPU memory required per GPU. Used for inference.
TuningPerGPUMemoryRequirement map[string]int // Min GPU memory per tuning method (batch size 1). Used for tuning.
WorldSize int // Defines the number of processes required for distributed inference.

RuntimeParam

// ReadinessTimeout defines the maximum duration for creating the workload.
// This timeout accommodates the size of the image, ensuring pull completion
// even under slower network conditions or unforeseen delays.
ReadinessTimeout time.Duration
WorldSize int // Defines the number of processes required for distributed inference.
Tag string // The model image tag
}

// RuntimeParam defines the llm runtime parameters. It carries one
// parameter set per supported runtime; the runtime selected at
// deployment time decides which set is used.
type RuntimeParam struct {
	// Transformers holds parameters for the Huggingface transformers runtime.
	Transformers HuggingfaceTransformersParam
	// VLLM holds parameters for the vLLM runtime.
	VLLM VLLMParam
}

// HuggingfaceTransformersParam holds the launch configuration for the
// Huggingface transformers runtime.
type HuggingfaceTransformersParam struct {
	BaseCommand        string            // The initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
	TorchRunParams     map[string]string // Parameters for configuring the torchrun command.
	TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
	InferenceMainFile  string            // The main file for inference.
	ModelRunParams     map[string]string // Parameters for running the model training/inference.
}

// VLLMParam holds the launch configuration for the vLLM runtime.
type VLLMParam struct {
	// BaseCommand is the command used to start the vLLM server.
	BaseCommand string
	// The model name used in the openai serving API.
	// see https://platform.openai.com/docs/api-reference/chat/create#chat-create-model.
	ModelName string
	// Parameters for distributed inference.
	DistributionParams map[string]string
	// Parameters for running the model training/inference.
	ModelRunParams map[string]string
}

// DeepCopy returns a deep copy of the preset parameters, or nil when the
// receiver is nil. Nested runtime parameters and maps are copied so the
// result shares no mutable state with the receiver.
func (p *PresetParam) DeepCopy() *PresetParam {
	if p == nil {
		return nil
	}
	out := new(PresetParam)
	*out = *p
	out.RuntimeParam = p.RuntimeParam.DeepCopy()
	// Preserve nilness: a nil map must stay nil in the copy (nil vs. empty
	// differ under serialization and reflect.DeepEqual), so only allocate
	// when the source map is non-nil.
	if p.TuningPerGPUMemoryRequirement != nil {
		out.TuningPerGPUMemoryRequirement = make(map[string]int, len(p.TuningPerGPUMemoryRequirement))
		for k, v := range p.TuningPerGPUMemoryRequirement {
			out.TuningPerGPUMemoryRequirement[k] = v
		}
	}
	return out
}

// DeepCopy returns a deep copy of the runtime parameters; a nil receiver
// yields the zero value.
func (rp *RuntimeParam) DeepCopy() RuntimeParam {
	if rp == nil {
		return RuntimeParam{}
	}
	return RuntimeParam{
		Transformers: rp.Transformers.DeepCopy(),
		VLLM:         rp.VLLM.DeepCopy(),
	}
}

// DeepCopy returns a deep copy of the transformers runtime parameters; a
// nil receiver yields the zero value. Nil maps stay nil in the copy (nil
// vs. empty differ under serialization and reflect.DeepEqual).
func (h *HuggingfaceTransformersParam) DeepCopy() HuggingfaceTransformersParam {
	if h == nil {
		return HuggingfaceTransformersParam{}
	}
	// copyStringMap clones a map while preserving nilness, unlike the
	// unconditional make+range idiom.
	copyStringMap := func(src map[string]string) map[string]string {
		if src == nil {
			return nil
		}
		dst := make(map[string]string, len(src))
		for k, v := range src {
			dst[k] = v
		}
		return dst
	}
	return HuggingfaceTransformersParam{
		BaseCommand:        h.BaseCommand,
		InferenceMainFile:  h.InferenceMainFile,
		TorchRunParams:     copyStringMap(h.TorchRunParams),
		TorchRunRdzvParams: copyStringMap(h.TorchRunRdzvParams),
		ModelRunParams:     copyStringMap(h.ModelRunParams),
	}
}

// DeepCopy returns a deep copy of the vLLM runtime parameters; a nil
// receiver yields the zero value. Nil maps stay nil in the copy (nil vs.
// empty differ under serialization and reflect.DeepEqual).
//
// Note: the original loop variables shadowed the receiver `v`
// (for k, v := range v.DistributionParams); distinct names are used here
// to avoid that confusion.
func (v *VLLMParam) DeepCopy() VLLMParam {
	if v == nil {
		return VLLMParam{}
	}
	// copyStringMap clones a map while preserving nilness.
	copyStringMap := func(src map[string]string) map[string]string {
		if src == nil {
			return nil
		}
		dst := make(map[string]string, len(src))
		for key, val := range src {
			dst[key] = val
		}
		return dst
	}
	return VLLMParam{
		BaseCommand:        v.BaseCommand,
		ModelName:          v.ModelName,
		DistributionParams: copyStringMap(v.DistributionParams),
		ModelRunParams:     copyStringMap(v.ModelRunParams),
	}
}

// GetInferenceCommand builds the container command for the given runtime.
// e.g. transformers: torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> <MAIN_FILE> <MODEL_PARAMS>
// An unknown runtime yields nil.
//
// Unlike the original, this does not write into p.VLLM.ModelRunParams:
// mutating the receiver from a getter leaks state across calls and
// panics when the map is nil. The vLLM flags are merged into a local copy
// instead, producing the identical command line.
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string) []string {
	switch runtime {
	case RuntimeNameHuggingfaceTransformers:
		torchCommand := utils.BuildCmdStr(p.Transformers.BaseCommand, p.Transformers.TorchRunParams, p.Transformers.TorchRunRdzvParams)
		modelCommand := utils.BuildCmdStr(p.Transformers.InferenceMainFile, p.Transformers.ModelRunParams)
		return utils.ShellCmd(torchCommand + " " + modelCommand)
	case RuntimeNameVLLM:
		// Merge the per-call flags into a copy so the shared preset stays untouched.
		params := make(map[string]string, len(p.VLLM.ModelRunParams)+2)
		for k, v := range p.VLLM.ModelRunParams {
			params[k] = v
		}
		if p.VLLM.ModelName != "" {
			// Served model name for the OpenAI-compatible API.
			params["served-model-name"] = p.VLLM.ModelName
		}
		// Shard the model across all GPUs of the SKU.
		params["tensor-parallel-size"] = skuNumGPUs
		return utils.ShellCmd(utils.BuildCmdStr(p.VLLM.BaseCommand, params))
	default:
		return nil
	}
}
10 changes: 4 additions & 6 deletions pkg/utils/common-preset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
package utils

import (
"github.com/kaito-project/kaito/pkg/utils/plugin"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)

const (
Expand Down Expand Up @@ -66,12 +66,14 @@ func ConfigSHMVolume(instanceCount int) (corev1.Volume, corev1.VolumeMount) {

// Signifies multinode inference requirement
if instanceCount > 1 {
size := resource.MustParse("4Gi")
// Append share memory volume to any existing volumes
volume = corev1.Volume{
Name: "dshm",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: "Memory",
Medium: "Memory",
SizeLimit: &size,
},
},
}
Expand Down Expand Up @@ -150,7 +152,3 @@ func ConfigAdapterVolume() (corev1.Volume, corev1.VolumeMount) {
}
return volume, volumeMount
}

func IsValidPreset(preset string) bool {
return plugin.KaitoModelRegister.Has(preset)
}
14 changes: 8 additions & 6 deletions pkg/utils/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ func MergeConfigMaps(baseMap, overrideMap map[string]string) map[string]string {
return merged
}

func BuildCmdStr(baseCommand string, runParams map[string]string) string {
func BuildCmdStr(baseCommand string, runParams ...map[string]string) string {
updatedBaseCommand := baseCommand
for key, value := range runParams {
if value == "" {
updatedBaseCommand = fmt.Sprintf("%s --%s", updatedBaseCommand, key)
} else {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
for _, runParam := range runParams {
for key, value := range runParam {
if value == "" {
updatedBaseCommand = fmt.Sprintf("%s --%s", updatedBaseCommand, key)
} else {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/utils/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package utils

import (
"context"
"sigs.k8s.io/controller-runtime/pkg/client"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -12,6 +11,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

Expand Down
5 changes: 4 additions & 1 deletion pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ const (
// RAGEngineFinalizer is used to make sure that ragengine controller handles garbage collection.
RAGEngineFinalizer = "ragengine.finalizer.kaito.sh"
DefaultReleaseNamespaceEnvVar = "RELEASE_NAMESPACE"
FeatureFlagKarpenter = "Karpenter"
AzureCloudName = "azure"
AWSCloudName = "aws"
GPUString = "gpu"
Expand All @@ -20,6 +19,10 @@ const (
GiBToBytes = 1024 * 1024 * 1024 // Conversion factor from GiB to bytes
NvidiaGPU = "nvidia.com/gpu"

// Feature flags
FeatureFlagKarpenter = "Karpenter"
FeatureFlagVLLM = "vLLM"

// Nodeclaim related consts
KaitoNodePoolName = "kaito"
LabelNodePool = "karpenter.sh/nodepool"
Expand Down
4 changes: 4 additions & 0 deletions pkg/utils/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ func (reg *ModelRegister) Has(name string) bool {
_, ok := reg.models[name]
return ok
}

// IsValidPreset reports whether the given preset name is registered in
// the Kaito model registry.
func IsValidPreset(preset string) bool {
	return KaitoModelRegister.Has(preset)
}
24 changes: 20 additions & 4 deletions pkg/utils/test/testModel.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ type testModel struct{}
func (*testModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: "python3",
RuntimeParam: model.RuntimeParam{
VLLM: model.VLLMParam{
BaseCommand: "python3 /workspace/vllm/inference_api.py",
},
Transformers: model.HuggingfaceTransformersParam{
BaseCommand: "accelerate launch",
InferenceMainFile: "/workspace/tfs/inference_api.py",
},
},
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testModel) GetTuningParameters() *model.PresetParam {
Expand All @@ -37,8 +45,16 @@ type testDistributedModel struct{}
func (*testDistributedModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: "python3",
RuntimeParam: model.RuntimeParam{
VLLM: model.VLLMParam{
BaseCommand: "python3 /workspace/vllm/inference_api.py",
},
Transformers: model.HuggingfaceTransformersParam{
BaseCommand: "accelerate launch",
InferenceMainFile: "/workspace/tfs/inference_api.py",
},
},
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testDistributedModel) GetTuningParameters() *model.PresetParam {
Expand Down
Loading

0 comments on commit 933107e

Please sign in to comment.