Skip to content

Commit

Permalink
feat: RAG engine deployment creation (#660)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Generate and create RAGengine deployment - part 1 

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

Signed-off-by: Bangqi Zhu <[email protected]>
Co-authored-by: Bangqi Zhu <[email protected]>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Nov 12, 2024
1 parent fcd5d1c commit cafb947
Show file tree
Hide file tree
Showing 9 changed files with 543 additions and 0 deletions.
8 changes: 8 additions & 0 deletions api/v1alpha1/condition_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ const (
// WorkspaceConditionTypeInferenceStatus is the state when Inference service has been ready.
WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady")

// RAGEneineConditionTypeServiceStatus is the state when service has been ready.
RAGEneineConditionTypeServiceStatus = ConditionType("ServiceReady")

// RAGConditionTypeServiceStatus is the state when RAG Engine service has been ready.
RAGConditionTypeServiceStatus = ConditionType("RAGEngineServiceReady")

// WorkspaceConditionTypeTuningJobStatus is the state when the tuning job starts normally.
WorkspaceConditionTypeTuningJobStatus ConditionType = ConditionType("JobStarted")

Expand All @@ -32,4 +38,6 @@ const (
//For inference, the "True" condition means the inference service is ready to serve requests.
//For fine tuning, the "True" condition means the tuning job completes successfully.
WorkspaceConditionTypeSucceeded ConditionType = ConditionType("WorkspaceSucceeded")

RAGEngineConditionTypeSucceeded ConditionType = ConditionType("RAGEngineSucceeded")
)
2 changes: 2 additions & 0 deletions charts/kaito/ragengine/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ spec:
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CLOUD_PROVIDER
value: {{ .Values.cloudProviderName }}
ports:
- name: http-metrics
containerPort: 8080
Expand Down
2 changes: 2 additions & 0 deletions charts/kaito/ragengine/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ resources:
nodeSelector: {}
tolerations: []
affinity: {}
# Values can be "azure" or "aws"
cloudProviderName: "azure"
113 changes: 113 additions & 0 deletions pkg/ragengine/controllers/preset-rag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package controllers

import (
"context"
"fmt"

"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"

kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/ragengine/manifests"
"github.com/kaito-project/kaito/pkg/utils/resources"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/intstr"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
ProbePath = "/health"
Port5000 = int32(5000)
)

var (
containerPorts = []corev1.ContainerPort{{
ContainerPort: Port5000,
},
}

livenessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Port: intstr.FromInt(5000),
Path: ProbePath,
},
},
InitialDelaySeconds: 600, // 10 minutes
PeriodSeconds: 10,
}

readinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Port: intstr.FromInt(5000),
Path: ProbePath,
},
},
InitialDelaySeconds: 30,
PeriodSeconds: 10,
}

tolerations = []corev1.Toleration{
{
Effect: corev1.TaintEffectNoSchedule,
Operator: corev1.TolerationOpExists,
Key: resources.CapacityNvidiaGPU,
},
{
Effect: corev1.TaintEffectNoSchedule,
Value: consts.GPUString,
Key: consts.SKUString,
Operator: corev1.TolerationOpEqual,
},
}
)

func CreatePresetRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine, revisionNum string, kubeClient client.Client) (client.Object, error) {
var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount

shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*ragEngineObj.Spec.Compute.Count)
if shmVolume.Name != "" {
volumes = append(volumes, shmVolume)
}
if shmVolumeMount.Name != "" {
volumeMounts = append(volumeMounts, shmVolumeMount)
}

var resourceReq corev1.ResourceRequirements

if ragEngineObj.Spec.Embedding.Local != nil {
skuNumGPUs, err := utils.GetSKUNumGPUs(ctx, kubeClient, ragEngineObj.Status.WorkerNodes,
ragEngineObj.Spec.Compute.InstanceType, "1")
if err != nil {
return nil, fmt.Errorf("failed to get SKU num GPUs: %v", err)
}

resourceReq = corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
Limits: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
}

}
commands := utils.ShellCmd("python3 main.py")
// TODO: provide this image
image := "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1"
imagePullSecretRefs := []corev1.LocalObjectReference{}

depObj := manifests.GenerateRAGDeploymentManifest(ctx, ragEngineObj, revisionNum, image, imagePullSecretRefs, *ragEngineObj.Spec.Compute.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)

err := resources.CreateResource(ctx, depObj, kubeClient)
if client.IgnoreAlreadyExists(err) != nil {
return nil, err
}
return depObj, nil
}
60 changes: 60 additions & 0 deletions pkg/ragengine/controllers/preset-rag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package controllers

import (
"context"
"os"
"strings"
"testing"

"github.com/kaito-project/kaito/pkg/utils/consts"
"github.com/kaito-project/kaito/pkg/utils/test"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
)

func TestCreatePresetRAG(t *testing.T) {
test.RegisterTestModel()

testcases := map[string]struct {
nodeCount int
callMocks func(c *test.MockClient)
expectedCmd string
expectedGPUReq string
expectedImage string
expectedVolume string
}{
"test-rag-model": {
nodeCount: 1,
callMocks: func(c *test.MockClient) {
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
expectedCmd: "/bin/sh -c python3 main.py",
expectedImage: "mcr.microsoft.com/aks/kaito/kaito-rag-service:0.0.1",
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
mockClient := test.NewClient()
tc.callMocks(mockClient)

ragEngineObj := test.MockRAGEngineWithPreset
createdObject, _ := CreatePresetRAG(context.TODO(), ragEngineObj, "1", mockClient)

workloadCmd := strings.Join((createdObject.(*appsv1.Deployment)).Spec.Template.Spec.Containers[0].Command, " ")

if workloadCmd != tc.expectedCmd {
t.Errorf("%s: main cmdline is not expected, got %s, expected %s", k, workloadCmd, tc.expectedCmd)
}

image := (createdObject.(*appsv1.Deployment)).Spec.Template.Spec.Containers[0].Image

if image != tc.expectedImage {
t.Errorf("%s: image is not expected, got %s, expected %s", k, image, tc.expectedImage)
}
})
}
}
67 changes: 67 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,78 @@ func (c *RAGEngineReconciler) ensureFinalizer(ctx context.Context, ragEngineObj
func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) {
err := c.applyRAGEngineResource(ctx, ragEngineObj)
if err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
"ragengineFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
return reconcile.Result{}, updateErr
}
// if error is due to machine/nodeClaim instance types unavailability, stop reconcile.
if err.Error() == consts.ErrorInstanceTypesUnavailable {
return reconcile.Result{Requeue: false}, err
}
return reconcile.Result{}, err
}
if err = c.applyRAG(ctx, ragEngineObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionFalse,
"ragengineFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
return reconcile.Result{}, updateErr
}
return reconcile.Result{}, err
}

if err = c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeSucceeded, metav1.ConditionTrue,
"ragengineSucceeded", "ragengine succeeds"); err != nil {
klog.ErrorS(err, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
return reconcile.Result{}, err
}
return reconcile.Result{}, nil
}

func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {
var err error
func() {

deployment := &appsv1.Deployment{}
revisionStr := ragEngineObj.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation]

if err = resources.GetResource(ctx, ragEngineObj.Name, ragEngineObj.Namespace, c.Client, deployment); err == nil {
klog.InfoS("An inference workload already exists for ragengine", "ragengine", klog.KObj(ragEngineObj))
return

} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
workloadObj, err = CreatePresetRAG(ctx, ragEngineObj, revisionStr, c.Client)
if err != nil {
return
}
if err = resources.CheckResourceStatus(workloadObj, c.Client, time.Duration(10)*time.Minute); err != nil {
return
}
}

}()

if err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGConditionTypeServiceStatus, metav1.ConditionFalse,
"RAGEngineServiceStatusFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
return updateErr
} else {
return err
}
}

if err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEneineConditionTypeServiceStatus, metav1.ConditionTrue,
"RAGEngineServiceSuccess", "Inference has been deployed successfully"); err != nil {
klog.ErrorS(err, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj))
return err
}

return nil
}

func (c *RAGEngineReconciler) deleteRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) {
klog.InfoS("deleteRAGEngine", "ragengine", klog.KObj(ragEngineObj))
err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeDeleting, metav1.ConditionTrue, "ragengineDeleted", "ragengine is being deleted")
Expand Down
Loading

0 comments on commit cafb947

Please sign in to comment.