From a03b89e764371e0e20270d212618b321ecaf7afc Mon Sep 17 00:00:00 2001
From: popsiclexu
Date: Thu, 15 Aug 2024 15:49:22 +0800
Subject: [PATCH] feat: add custom restful backend for complex scenarios
 (e.g., RAG)

Signed-off-by: popsiclexu
Signed-off-by: popsiclexu
Signed-off-by: popsiclexu
---
 README.md                |   1 +
 pkg/ai/customrest.go     | 147 +++++++++++++++++++++++++++++++++++++++
 pkg/ai/iai.go            |   4 +-
 pkg/ai/prompts.go        |   2 +
 pkg/analysis/analysis.go |   3 +
 5 files changed, 156 insertions(+), 1 deletion(-)
 create mode 100644 pkg/ai/customrest.go

diff --git a/README.md b/README.md
index 08725f8435..d9accfa1f5 100644
--- a/README.md
+++ b/README.md
@@ -343,6 +343,7 @@ Unused:
 > noopai
 > googlevertexai
 > watsonxai
+> customrest
 ```
 
 For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
diff --git a/pkg/ai/customrest.go b/pkg/ai/customrest.go
new file mode 100644
index 0000000000..c72d0b3176
--- /dev/null
+++ b/pkg/ai/customrest.go
@@ -0,0 +1,147 @@
+package ai
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+)
+
+const CustomRestClientName = "customrest"
+
+type CustomRestClient struct {
+	nopCloser
+
+	client      *http.Client
+	base        *url.URL
+	token       string
+	model       string
+	temperature float32
+	topP        float32
+	topK        int32
+}
+
+type CustomRestRequest struct {
+	Model string `json:"model"`
+
+	// Prompt is the textual prompt to send to the model.
+	Prompt string `json:"prompt"`
+
+	// Options lists model-specific options. For example, temperature can be
+	// set through this field, if the model supports it.
+	Options map[string]interface{} `json:"options"`
+}
+
+type CustomRestResponse struct {
+	// Model is the model name that generated the response.
+	Model string `json:"model"`
+
+	// CreatedAt is the timestamp of the response.
+	CreatedAt time.Time `json:"created_at"`
+
+	// Response is the textual response itself.
+	Response string `json:"response"`
+}
+
+func (c *CustomRestClient) Configure(config IAIConfig) error {
+	baseURL := config.GetBaseURL()
+	if baseURL == "" {
+		baseURL = defaultBaseURL
+	}
+	c.token = config.GetPassword()
+	baseClientURL, err := url.Parse(baseURL)
+	if err != nil {
+		return err
+	}
+	c.base = baseClientURL
+
+	proxyEndpoint := config.GetProxyEndpoint()
+	c.client = http.DefaultClient
+	if proxyEndpoint != "" {
+		proxyUrl, err := url.Parse(proxyEndpoint)
+		if err != nil {
+			return err
+		}
+		transport := &http.Transport{
+			Proxy: http.ProxyURL(proxyUrl),
+		}
+
+		c.client = &http.Client{
+			Transport: transport,
+		}
+	}
+
+	c.model = config.GetModel()
+	if c.model == "" {
+		c.model = defaultModel
+	}
+	c.temperature = config.GetTemperature()
+	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
+	return nil
+}
+
+func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+	var promptDetail struct {
+		Language string `json:"language,omitempty"`
+		Message  string `json:"message"`
+		Prompt   string `json:"prompt,omitempty"`
+	}
+	prompt = strings.ReplaceAll(prompt, "\n", "\\n")
+	prompt = strings.ReplaceAll(prompt, "\t", "\\t")
+	if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil {
+		return "", err
+	}
+	generateRequest := &CustomRestRequest{
+		Model:  c.model,
+		Prompt: promptDetail.Prompt,
+		Options: map[string]interface{}{
+			"temperature": c.temperature,
+			"top_p":       c.topP,
+			"top_k":       c.topK,
+			"message":     promptDetail.Message,
+			"language":    promptDetail.Language,
+		},
+	}
+	bts, err := json.Marshal(generateRequest)
+	if err != nil {
+		return "", err
+	}
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(bts))
+	if err != nil {
+		return "", err
+	}
+	if c.token != "" {
+		request.Header.Set("Authorization", "Bearer "+c.token)
+	}
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Accept", "application/x-ndjson")
+	response, err := c.client.Do(request)
+	if err != nil {
+		return "", err
+	}
+	defer response.Body.Close()
+
+	resBody, err := io.ReadAll(response.Body)
+	if err != nil {
+		return "", fmt.Errorf("could not read response body: %w", err)
+	}
+
+	if response.StatusCode >= http.StatusBadRequest {
+		return "", fmt.Errorf("Request Error, StatusCode: %d, ErrorMessage: %s", response.StatusCode, resBody)
+	}
+
+	var resp CustomRestResponse
+	if err := json.Unmarshal(resBody, &resp); err != nil {
+		return "", err
+	}
+	return resp.Response, nil
+}
+
+func (c *CustomRestClient) GetName() string {
+	return CustomRestClientName
+}
diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go
index 38c8500346..b94373a17f 100644
--- a/pkg/ai/iai.go
+++ b/pkg/ai/iai.go
@@ -33,6 +33,7 @@ var (
 		&GoogleVertexAIClient{},
 		&OCIGenAIClient{},
 		&WatsonxAIClient{},
+		&CustomRestClient{},
 	}
 	Backends = []string{
 		openAIClientName,
@@ -48,6 +49,7 @@ var (
 		googleVertexAIClientName,
 		ociClientName,
 		watsonxAIClientName,
+		CustomRestClientName,
 	}
 )
 
@@ -181,7 +183,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
 	return p.CustomHeaders
 }
 
-var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
+var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai", "customrest"}
 
 func NeedPassword(backend string) bool {
 	for _, b := range passwordlessProviders {
diff --git a/pkg/ai/prompts.go b/pkg/ai/prompts.go
index e9cd480d1d..ca8c7625cc 100644
--- a/pkg/ai/prompts.go
+++ b/pkg/ai/prompts.go
@@ -58,9 +58,11 @@ const (
 	Solution: {kubectl command}
 	`
+	raw_prompt = `{"language": "%s","message": "%s","prompt": "%s"}`
 )
 
 var PromptMap = map[string]string{
+	"raw":     raw_prompt,
 	"default": default_prompt,
 	"VulnerabilityReport": trivy_vuln_prompt, // for Trivy integration, the key should match `Result.Kind` in pkg/common/types.go
 	"ConfigAuditReport":   trivy_conf_prompt,
diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go
index 1a28c9846a..65b34b8e24 100644
--- a/pkg/analysis/analysis.go
+++ b/pkg/analysis/analysis.go
@@ -397,6 +397,9 @@ func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl st
 	// Process template.
 	prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
 
+	if a.AIClient.GetName() == ai.CustomRestClientName {
+		prompt = fmt.Sprintf(ai.PromptMap["raw"], a.Language, inputKey, prompt)
+	}
 	response, err := a.AIClient.GetCompletion(a.Context, prompt)
 	if err != nil {
 		return "", err