fix: workspace condition
ishaansehgal99 committed Mar 21, 2024
Parent: 931b037 · Commit: 107c501
Showing 10 changed files with 57 additions and 51 deletions.
10 changes: 8 additions & 2 deletions api/v1alpha1/workspace_condition_types.go
@@ -16,8 +16,14 @@ const (
 	// WorkspaceConditionTypeInferenceStatus is the state when Inference has been created.
 	WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady")
 
-	// WorkspaceConditionTypeTuningStatus is the state when Tuning has been created.
-	WorkspaceConditionTypeTuningStatus = ConditionType("TuningReady")
+	// WorkspaceConditionTypeTuningStarted indicates that the tuning Job has been started.
+	WorkspaceConditionTypeTuningStarted = ConditionType("TuningStarted")
+
+	// WorkspaceConditionTypeTuningComplete indicates that the tuning Job has completed successfully.
+	WorkspaceConditionTypeTuningComplete = ConditionType("TuningComplete")
+
+	// WorkspaceConditionTypeTuningFailed indicates that the tuning Job has failed to complete.
+	WorkspaceConditionTypeTuningFailed = ConditionType("TuningFailed")
 
 	//WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted.
 	WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting")
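
Note: splitting the old single TuningReady condition into TuningStarted, TuningComplete, and TuningFailed lets consumers distinguish an in-flight tuning Job from a finished one. Below is a minimal sketch of how a caller might fold the three conditions into one coarse state; the workspaceTuningState helper and its return values are illustrative, not part of this commit, while meta.FindStatusCondition is the stock apimachinery helper.

package conditions

import (
	"k8s.io/apimachinery/pkg/api/meta"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// workspaceTuningState is an illustrative helper (not part of this commit)
// that collapses the three tuning conditions into one coarse state string.
func workspaceTuningState(conds []metav1.Condition) string {
	// Failed and Complete are terminal, so check them before Started.
	if c := meta.FindStatusCondition(conds, "TuningFailed"); c != nil && c.Status == metav1.ConditionTrue {
		return "failed"
	}
	if c := meta.FindStatusCondition(conds, "TuningComplete"); c != nil && c.Status == metav1.ConditionTrue {
		return "complete"
	}
	if c := meta.FindStatusCondition(conds, "TuningStarted"); c != nil && c.Status == metav1.ConditionTrue {
		return "running"
	}
	return "pending"
}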
8 changes: 4 additions & 4 deletions api/v1alpha1/workspace_validation_test.go
@@ -28,7 +28,7 @@ func (*testModel) GetInferenceParameters() *model.PresetParam {
 		PerGPUMemoryRequirement: perGPUMemoryRequirement,
 	}
 }
-func (*testModel) GetTrainingParameters() *model.PresetParam {
+func (*testModel) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		GPUCountRequirement: gpuCountRequirement,
 		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
@@ -38,7 +38,7 @@ func (*testModel) GetTrainingParameters() *model.PresetParam {
 func (*testModel) SupportDistributedInference() bool {
 	return false
 }
-func (*testModel) SupportTraining() bool {
+func (*testModel) SupportTuning() bool {
 	return true
 }
 
@@ -52,7 +52,7 @@ func (*testModelPrivate) GetInferenceParameters() *model.PresetParam {
 		PerGPUMemoryRequirement: perGPUMemoryRequirement,
 	}
 }
-func (*testModelPrivate) GetTrainingParameters() *model.PresetParam {
+func (*testModelPrivate) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		ImageAccessMode: "private",
 		GPUCountRequirement: gpuCountRequirement,
@@ -63,7 +63,7 @@ func (*testModelPrivate) GetTrainingParameters() *model.PresetParam {
 func (*testModelPrivate) SupportDistributedInference() bool {
 	return false
 }
-func (*testModelPrivate) SupportTraining() bool {
+func (*testModelPrivate) SupportTuning() bool {
 	return true
 }
26 changes: 13 additions & 13 deletions pkg/controllers/workspace_controller.go
@@ -5,13 +5,12 @@ package controllers
 import (
 	"context"
 	"fmt"
-	"github.com/azure/kaito/pkg/tuning"
 	"sort"
 	"strings"
 	"time"
 
-	appsv1 "k8s.io/api/apps/v1"
-	"k8s.io/utils/clock"
+	"github.com/azure/kaito/pkg/tuning"
+	batchv1 "k8s.io/api/batch/v1"
 
 	"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
 	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
@@ -22,13 +21,15 @@ import (
 	"github.com/azure/kaito/pkg/utils/plugin"
 	"github.com/go-logr/logr"
 	"github.com/samber/lo"
+	appsv1 "k8s.io/api/apps/v1"
 	corev1 "k8s.io/api/core/v1"
-	"k8s.io/apimachinery/pkg/api/errors"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/client-go/tools/record"
 	"k8s.io/klog/v2"
+	"k8s.io/utils/clock"
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/controller"
@@ -442,30 +443,29 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph
 	presetName := string(wObj.Tuning.Preset.Name)
 	model := plugin.KaitoModelRegister.MustGet(presetName)
 
-	trainingParam := model.GetTrainingParameters()
-
-	existingObj := &appsv1.Deployment{}
+	tuningParam := model.GetTuningParameters()
+	existingObj := &batchv1.Job{}
 	if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
-		klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj))
-		if err = resources.CheckResourceStatus(existingObj, c.Client, trainingParam.WorkloadTimeout); err != nil {
+		klog.InfoS("A tuning workload already exists for workspace", "workspace", klog.KObj(wObj))
+		if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.WorkloadTimeout); err != nil {
 			return
 		}
 	} else if apierrors.IsNotFound(err) {
 		var workloadObj client.Object
 		// Need to create a new workload
-		workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, trainingParam, c.Client)
+		workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, tuningParam, c.Client)
 		if err != nil {
 			return
 		}
-		if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.WorkloadTimeout); err != nil {
+		if err = resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.WorkloadTimeout); err != nil {
 			return
 		}
 	}
 	}
 	}()
 
 	if err != nil {
-		if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionFalse,
+		if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningFailed, metav1.ConditionFalse,
 			"WorkspaceTuningStatusFailed", err.Error()); updateErr != nil {
 			klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
 			return updateErr
@@ -475,8 +475,8 @@ }
 		}
 	}
 
-	if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionTrue,
-		"WorkspaceTuningStatusSuccess", "Tuning has been deployed successfully"); err != nil {
+	if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStarted, metav1.ConditionTrue,
+		"WorkspaceTuningStatusStarted", "Tuning has been deployed successfully"); err != nil {
 		klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj))
 		return err
 	}
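
Note: the applyTuning change above swaps the expected workload type from appsv1.Deployment to batchv1.Job, which fits tuning's run-to-completion semantics: a Job reports terminal Complete/Failed conditions rather than a steady-state ready-replica count. A hedged sketch of the kind of terminal-state check a Job-aware CheckResourceStatus might perform follows; the jobFinished helper is illustrative, and the real CheckResourceStatus implementation is not shown in this diff.

package resources

import (
	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
)

// jobFinished reports whether a tuning Job has reached a terminal state and,
// if so, whether it succeeded. Illustrative only; not the real
// CheckResourceStatus from this repository.
func jobFinished(job *batchv1.Job) (finished, succeeded bool) {
	for _, cond := range job.Status.Conditions {
		if cond.Status != corev1.ConditionTrue {
			continue
		}
		switch cond.Type {
		case batchv1.JobComplete: // corresponds to the new TuningComplete condition
			return true, true
		case batchv1.JobFailed: // corresponds to the new TuningFailed condition
			return true, false
		}
	}
	return false, false // still running: only TuningStarted applies
}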
4 changes: 2 additions & 2 deletions pkg/model/interface.go
@@ -8,9 +8,9 @@ import (
 
 type Model interface {
 	GetInferenceParameters() *PresetParam
-	GetTrainingParameters() *PresetParam
+	GetTuningParameters() *PresetParam
 	SupportDistributedInference() bool // If true, the model workload will be a StatefulSet, using the torch elastic runtime framework.
-	SupportTraining() bool
+	SupportTuning() bool
 }
 
 // PresetParam defines the preset inference parameters for a model.
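
Note: with the interface renamed, an implementation now looks like the sketch below. The stubModel type and its field values are illustrative and not part of this commit; the PresetParam field names are taken from the test fixtures in this diff.

package example

import "github.com/azure/kaito/pkg/model"

// stubModel is an illustrative implementation of the renamed interface.
type stubModel struct{}

func (*stubModel) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{GPUCountRequirement: "1"}
}

// GetTuningParameters replaces the old GetTrainingParameters. Models that
// cannot be tuned return nil here (see the instruct presets below).
func (*stubModel) GetTuningParameters() *model.PresetParam {
	return &model.PresetParam{GPUCountRequirement: "1"}
}

func (*stubModel) SupportDistributedInference() bool { return false }
func (*stubModel) SupportTuning() bool               { return true }

// Compile-time check that stubModel satisfies model.Model.
var _ model.Model = (*stubModel)(nil)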
8 changes: 4 additions & 4 deletions pkg/utils/testModel.go
@@ -18,7 +18,7 @@ func (*testModel) GetInferenceParameters() *model.PresetParam {
 		WorkloadTimeout: time.Duration(30) * time.Minute,
 	}
 }
-func (*testModel) GetTrainingParameters() *model.PresetParam {
+func (*testModel) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		GPUCountRequirement: "1",
 		WorkloadTimeout: time.Duration(30) * time.Minute,
@@ -27,7 +27,7 @@ func (*testModel) GetTrainingParameters() *model.PresetParam {
 func (*testModel) SupportDistributedInference() bool {
 	return false
 }
-func (*testModel) SupportTraining() bool {
+func (*testModel) SupportTuning() bool {
 	return true
 }
 
@@ -39,7 +39,7 @@ func (*testDistributedModel) GetInferenceParameters() *model.PresetParam {
 		WorkloadTimeout: time.Duration(30) * time.Minute,
 	}
 }
-func (*testDistributedModel) GetTrainingParameters() *model.PresetParam {
+func (*testDistributedModel) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		GPUCountRequirement: "1",
 		WorkloadTimeout: time.Duration(30) * time.Minute,
@@ -48,7 +48,7 @@ func (*testDistributedModel) GetTrainingParameters() *model.PresetParam {
 func (*testDistributedModel) SupportDistributedInference() bool {
 	return true
 }
-func (*testDistributedModel) SupportTraining() bool {
+func (*testDistributedModel) SupportTuning() bool {
 	return true
 }
16 changes: 8 additions & 8 deletions presets/models/falcon/model.go
@@ -69,7 +69,7 @@ func (*falcon7b) GetInferenceParameters() *model.PresetParam {
 		Tag: PresetFalconTagMap["Falcon7B"],
 	}
 }
-func (*falcon7b) GetTrainingParameters() *model.PresetParam {
+func (*falcon7b) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		ModelFamilyName: "Falcon",
 		ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
@@ -88,7 +88,7 @@ func (*falcon7b) GetTrainingParameters() *model.PresetParam {
 func (*falcon7b) SupportDistributedInference() bool {
 	return false
 }
-func (*falcon7b) SupportTraining() bool {
+func (*falcon7b) SupportTuning() bool {
 	return true
 }
 
@@ -112,13 +112,13 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetParam {
 	}
 
 }
-func (*falcon7bInst) GetTrainingParameters() *model.PresetParam {
+func (*falcon7bInst) GetTuningParameters() *model.PresetParam {
 	return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned
 }
 func (*falcon7bInst) SupportDistributedInference() bool {
 	return false
 }
-func (*falcon7bInst) SupportTraining() bool {
+func (*falcon7bInst) SupportTuning() bool {
 	return false
 }
 
@@ -142,7 +142,7 @@ func (*falcon40b) GetInferenceParameters() *model.PresetParam {
 	}
 
 }
-func (*falcon40b) GetTrainingParameters() *model.PresetParam {
+func (*falcon40b) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		ModelFamilyName: "Falcon",
 		ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
@@ -160,7 +160,7 @@ func (*falcon40b) GetTrainingParameters() *model.PresetParam {
 func (*falcon40b) SupportDistributedInference() bool {
 	return false
 }
-func (*falcon40b) SupportTraining() bool {
+func (*falcon40b) SupportTuning() bool {
 	return true
 }
 
@@ -183,12 +183,12 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetParam {
 		Tag: PresetFalconTagMap["Falcon40BInstruct"],
 	}
 }
-func (*falcon40bInst) GetTrainingParameters() *model.PresetParam {
+func (*falcon40bInst) GetTuningParameters() *model.PresetParam {
 	return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned
 }
 func (*falcon40bInst) SupportDistributedInference() bool {
 	return false
 }
-func (*falcon40bInst) SupportTraining() bool {
+func (*falcon40bInst) SupportTuning() bool {
 	return false
 }
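
Note: the instruct presets above return nil from GetTuningParameters while their base models return real parameters, so any caller should gate on SupportTuning and nil-check before dereferencing fields such as WorkloadTimeout. An illustrative guard, not part of this diff:

package example

import (
	"fmt"

	"github.com/azure/kaito/pkg/model"
)

// resolveTuningParams is an illustrative guard (not part of this diff):
// instruct presets return nil from GetTuningParameters, so callers must
// check before dereferencing the result.
func resolveTuningParams(name string, m model.Model) (*model.PresetParam, error) {
	if !m.SupportTuning() {
		return nil, fmt.Errorf("preset %q does not support tuning", name)
	}
	p := m.GetTuningParameters()
	if p == nil {
		return nil, fmt.Errorf("preset %q has no tuning parameters", name)
	}
	return p, nil
}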
12 changes: 6 additions & 6 deletions presets/models/llama2/model.go
@@ -56,13 +56,13 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetParam {
 	}
 
 }
-func (*llama2Text7b) GetTrainingParameters() *model.PresetParam {
+func (*llama2Text7b) GetTuningParameters() *model.PresetParam {
 	return nil // Currently doesn't support fine-tuning
 }
 func (*llama2Text7b) SupportDistributedInference() bool {
 	return false
 }
-func (*llama2Text7b) SupportTraining() bool {
+func (*llama2Text7b) SupportTuning() bool {
 	return false
 }
 
@@ -87,13 +87,13 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetParam {
 		// Tag: llama has private image access mode. The image tag is determined by the user.
 	}
 }
-func (*llama2Text13b) GetTrainingParameters() *model.PresetParam {
+func (*llama2Text13b) GetTuningParameters() *model.PresetParam {
 	return nil // Currently doesn't support fine-tuning
 }
 func (*llama2Text13b) SupportDistributedInference() bool {
 	return true
 }
-func (*llama2Text13b) SupportTraining() bool {
+func (*llama2Text13b) SupportTuning() bool {
 	return false
 }
 
@@ -118,12 +118,12 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetParam {
 		// Tag: llama has private image access mode. The image tag is determined by the user.
 	}
 }
-func (*llama2Text70b) GetTrainingParameters() *model.PresetParam {
+func (*llama2Text70b) GetTuningParameters() *model.PresetParam {
 	return nil // Currently doesn't support fine-tuning
 }
 func (*llama2Text70b) SupportDistributedInference() bool {
 	return true
 }
-func (*llama2Text70b) SupportTraining() bool {
+func (*llama2Text70b) SupportTuning() bool {
 	return false
 }
12 changes: 6 additions & 6 deletions presets/models/llama2chat/model.go
@@ -55,13 +55,13 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam {
 		// Tag: llama has private image access mode. The image tag is determined by the user.
 	}
 }
-func (*llama2Chat7b) GetTrainingParameters() *model.PresetParam {
+func (*llama2Chat7b) GetTuningParameters() *model.PresetParam {
 	return nil // Currently doesn't support fine-tuning
 }
 func (*llama2Chat7b) SupportDistributedInference() bool {
 	return false
 }
-func (*llama2Chat7b) SupportTraining() bool {
+func (*llama2Chat7b) SupportTuning() bool {
 	return false
 }
 
@@ -86,13 +86,13 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam {
 		// Tag: llama has private image access mode. The image tag is determined by the user.
 	}
 }
-func (*llama2Chat13b) GetTrainingParameters() *model.PresetParam {
+func (*llama2Chat13b) GetTuningParameters() *model.PresetParam {
 	return nil // Currently doesn't support fine-tuning
 }
 func (*llama2Chat13b) SupportDistributedInference() bool {
 	return true
 }
-func (*llama2Chat13b) SupportTraining() bool {
+func (*llama2Chat13b) SupportTuning() bool {
 	return false
 }
 
@@ -117,12 +117,12 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam {
 		// Tag: llama has private image access mode. The image tag is determined by the user.
 	}
 }
-func (*llama2Chat70b) GetTrainingParameters() *model.PresetParam {
+func (*llama2Chat70b) GetTuningParameters() *model.PresetParam {
 	return nil // Currently doesn't support fine-tuning
 }
 func (*llama2Chat70b) SupportDistributedInference() bool {
 	return true
 }
-func (*llama2Chat70b) SupportTraining() bool {
+func (*llama2Chat70b) SupportTuning() bool {
 	return false
 }
8 changes: 4 additions & 4 deletions presets/models/mistral/model.go
@@ -58,7 +58,7 @@ func (*mistral7b) GetInferenceParameters() *model.PresetParam {
 	}
 
 }
-func (*mistral7b) GetTrainingParameters() *model.PresetParam {
+func (*mistral7b) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		ModelFamilyName: "Mistral",
 		ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
@@ -77,7 +77,7 @@ func (*mistral7b) GetTrainingParameters() *model.PresetParam {
 func (*mistral7b) SupportDistributedInference() bool {
 	return false
 }
-func (*mistral7b) SupportTraining() bool {
+func (*mistral7b) SupportTuning() bool {
 	return true
 }
 
@@ -101,12 +101,12 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetParam {
 	}
 
 }
-func (*mistral7bInst) GetTrainingParameters() *model.PresetParam {
+func (*mistral7bInst) GetTuningParameters() *model.PresetParam {
 	return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned
 }
 func (*mistral7bInst) SupportDistributedInference() bool {
 	return false
 }
-func (*mistral7bInst) SupportTraining() bool {
+func (*mistral7bInst) SupportTuning() bool {
 	return false
 }
4 changes: 2 additions & 2 deletions presets/models/phi/model.go
@@ -51,7 +51,7 @@ func (*phi2) GetInferenceParameters() *model.PresetParam {
 		Tag: PresetPhiTagMap["Phi2"],
 	}
 }
-func (*phi2) GetTrainingParameters() *model.PresetParam {
+func (*phi2) GetTuningParameters() *model.PresetParam {
 	return &model.PresetParam{
 		ModelFamilyName: "Phi",
 		ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
@@ -69,6 +69,6 @@ func (*phi2) GetTrainingParameters() *model.PresetParam {
 func (*phi2) SupportDistributedInference() bool {
 	return false
 }
-func (*phi2) SupportTraining() bool {
+func (*phi2) SupportTuning() bool {
 	return true
 }
