diff --git a/pkg/core/handler/search/search.go b/pkg/core/handler/search/search.go index 35bb3623..df91bbdb 100644 --- a/pkg/core/handler/search/search.go +++ b/pkg/core/handler/search/search.go @@ -35,7 +35,7 @@ import ( // @Tags search // @Produce json // @Param query query string true "The query to use for search. Required" -// @Param pattern query string true "The search pattern. Can be either sql or dsl. Required" +// @Param pattern query string true "The search pattern. Can be either sql, dsl or nl. Required" // @Param pageSize query string false "The size of the page. Default to 10" // @Param page query string false "The current page to fetch. Default to 1" // @Success 200 {array} runtime.Object "Array of runtime.Object" @@ -66,9 +66,42 @@ func SearchForResource(searchMgr *search.SearchManager, aiMgr *ai.AIManager, sea searchPage = 1 } + query := searchQuery + + if searchPattern == storage.NLPatternType { + //logger.Info(searchQuery) + res, err := aiMgr.ConvertTextToSQL(searchQuery) + if err != nil { + handler.FailureRender(ctx, w, r, err) + return + } + searchQuery = res + } + + //logger.Info(searchQuery) logger.Info("Searching for resources...", "page", searchPage, "pageSize", searchPageSize) res, err := searchStorage.Search(ctx, searchQuery, searchPattern, &storage.Pagination{Page: searchPage, PageSize: searchPageSize}) + if err != nil { + if searchPattern == storage.NLPatternType { + //logger.Info(err.Error()) + fixedQuery, fixErr := aiMgr.FixSQL(query, searchQuery, err.Error()) + if fixErr != nil { + handler.FailureRender(ctx, w, r, err) + return + } + searchQuery = fixedQuery + res, err = searchStorage.Search(ctx, searchQuery, searchPattern, &storage.Pagination{Page: searchPage, PageSize: searchPageSize}) + if err != nil { + handler.FailureRender(ctx, w, r, err) + return + } + } else { + handler.FailureRender(ctx, w, r, err) + return + } + } + if err != nil { handler.FailureRender(ctx, w, r, err) return @@ -83,6 +116,7 @@ func SearchForResource(searchMgr *search.SearchManager, aiMgr *ai.AIManager, sea Object: unObj, }) } + rt.SQLQuery = searchQuery rt.Total = res.Total rt.CurrentPage = searchPage rt.PageSize = searchPageSize diff --git a/pkg/core/manager/ai/search.go b/pkg/core/manager/ai/search.go new file mode 100644 index 00000000..4c814518 --- /dev/null +++ b/pkg/core/manager/ai/search.go @@ -0,0 +1,43 @@ +// Copyright The Karpor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ai + +import ( + "context" + "fmt" + "github.com/KusionStack/karpor/pkg/infra/ai" +) + +// ConvertTextToSQL converts natural language text to an SQL query +func (a *AIManager) ConvertTextToSQL(query string) (string, error) { + servicePrompt := ai.ServicePromptMap[ai.Text2sqlType] + prompt := fmt.Sprintf(servicePrompt, query) + res, err := a.client.Generate(context.Background(), prompt) + if err != nil { + return "", err + } + return ExtractSelectSQL(res), nil +} + +// FixSQL fix the error SQL +func (a *AIManager) FixSQL(sql string, query string, error string) (string, error) { + servicePrompt := ai.ServicePromptMap[ai.SqlFixType] + prompt := fmt.Sprintf(servicePrompt, query, sql, error) + res, err := a.client.Generate(context.Background(), prompt) + if err != nil { + return "", err + } + return ExtractSelectSQL(res), nil +} diff --git a/pkg/core/manager/ai/util.go b/pkg/core/manager/ai/util.go new file mode 100644 index 00000000..effe2be3 --- /dev/null +++ b/pkg/core/manager/ai/util.go @@ -0,0 +1,24 @@ +// Copyright The Karpor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ai + +import "regexp" + +// ExtractSelectSQL extracts SQL statements that start with "SELECT * FROM" +func ExtractSelectSQL(sql string) string { + res := regexp.MustCompile(`(?i)SELECT \* FROM [^;]+`) + match := res.FindString(sql) + return match +} diff --git a/pkg/core/manager/ai/util_test.go b/pkg/core/manager/ai/util_test.go new file mode 100644 index 00000000..d356d0b4 --- /dev/null +++ b/pkg/core/manager/ai/util_test.go @@ -0,0 +1,44 @@ +// Copyright The Karpor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ai + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +// TestExtractSelectSQL tests the correctness of the ExtractSelectSQL function. +func TestExtractSelectSQL(t *testing.T) { + testCases := []struct { + name string + sql string + expected string + }{ + { + name: "NormalCase", + sql: "Q: 所有kind=namespace " + + "Schema_links: [kind, namespace] " + + "SQL: select * from resources where kind='namespace';", + expected: "select * from resources where kind='namespace'", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := ExtractSelectSQL(tc.sql) + require.Equal(t, tc.expected, actual) + }) + } +} diff --git a/pkg/core/manager/search/types.go b/pkg/core/manager/search/types.go index a6cc5c99..0117e73e 100644 --- a/pkg/core/manager/search/types.go +++ b/pkg/core/manager/search/types.go @@ -34,6 +34,7 @@ type UniResource struct { type UniResourceList struct { metav1.TypeMeta Items []UniResource `json:"items"` + SQLQuery string `json:"sqlQuery"` Total int `json:"total"` CurrentPage int `json:"currentPage"` PageSize int `json:"pageSize"` diff --git a/pkg/infra/ai/azureopenai.go b/pkg/infra/ai/azureopenai.go index 37f9762e..d48aae38 100644 --- a/pkg/infra/ai/azureopenai.go +++ b/pkg/infra/ai/azureopenai.go @@ -47,16 +47,11 @@ func (c *AzureAIClient) Configure(cfg AIConfig) error { return nil } -func (c *AzureAIClient) Generate(ctx context.Context, prompt string, serviceType string) (string, error) { - servicePrompt := ServicePromptMap[serviceType] +func (c *AzureAIClient) Generate(ctx context.Context, prompt string) (string, error) { resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: c.model, Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: servicePrompt, - }, { Role: openai.ChatMessageRoleUser, Content: prompt, diff --git a/pkg/infra/ai/huggingface.go b/pkg/infra/ai/huggingface.go index edd5e6eb..b61e28fe 100644 --- a/pkg/infra/ai/huggingface.go +++ b/pkg/infra/ai/huggingface.go @@ -34,7 +34,7 @@ func (c *HuggingfaceClient) Configure(cfg AIConfig) error { return nil } -func (c *HuggingfaceClient) Generate(ctx context.Context, prompt string, serviceType string) (string, error) { +func (c *HuggingfaceClient) Generate(ctx context.Context, prompt string) (string, error) { resp, err := c.client.TextGeneration(ctx, &huggingface.TextGenerationRequest{ Inputs: prompt, Parameters: huggingface.TextGenerationParameters{ diff --git a/pkg/infra/ai/openai.go b/pkg/infra/ai/openai.go index 82ef8c14..8de36a2d 100644 --- a/pkg/infra/ai/openai.go +++ b/pkg/infra/ai/openai.go @@ -48,16 +48,10 @@ func (c *OpenAIClient) Configure(cfg AIConfig) error { return nil } -func (c *OpenAIClient) Generate(ctx context.Context, prompt string, serviceType string) (string, error) { - servicePrompt := ServicePromptMap[serviceType] - +func (c *OpenAIClient) Generate(ctx context.Context, prompt string) (string, error) { resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: c.model, Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: servicePrompt, - }, { Role: openai.ChatMessageRoleUser, Content: prompt, diff --git a/pkg/infra/ai/prompts.go b/pkg/infra/ai/prompts.go index 7d4479a5..06feceae 100644 --- a/pkg/infra/ai/prompts.go +++ b/pkg/infra/ai/prompts.go @@ -16,10 +16,75 @@ package ai const ( default_prompt = "You are a helpful assistant." - search_prompt = "You are an AI specialized in writing SQL queries. Please provide the SQL statement as a formatted string." + + text2sql_prompt = ` + You are an AI specialized in writing SQL queries. + Please convert the text %s to sql. + The output tokens only need to give the SQL first, the other thought process please do not give. + The SQL should begin with "select * from" and end with ";". + + 1. The database now only supports one table resources. + + Table resources, columns = [cluster, apiVersion, kind, + namespace, name, creationTimestamp, deletionTimestamp, ownerReferences, + resourceVersion, labels.[key], annotations.[key], content] + + 2. find the schema_links for generating SQL queries for each question based on the database schema. + + Follow are some examples. + + Q: find the kind which is not equal to pod + A: Let’s think step by step. In the question "find the kind column which is not equal to pod", we are asked: + "find the kind" so we need column = [kind] + Based on the columns, the set of possible cell values are = [pod]. + So the Schema_links are: + Schema_links: [kind, pod] + + Q: find the kind Deployment which created before January 1, 2024, at 18:00:00 + A: Let’s think step by step. In the question "find the kind Deployment which created before January 1, 2024, at 18:00:00", we are asked: + "find the kind Deployment" so we need column = [kind] + "created before" so we need column = [creationTimestamp] + Based on the columns, the set of possible cell values are = [Deployment, 2024-01-01T18:00:00Z]. + So the Schema_links are: + Schema_links: [kind, creationTimestamp, Deployment, 2024-01-01T18:00:00Z] + + 3. Use the the schema links to generate the SQL queries for each of the questions. + + Follow are some examples. + + Q: find the kind which is not equal to pod + Schema_links: [kind, pod] + SQL: select * from resources where kind!='Pod'; + + Q: find the kind Deployment which created before January 1, 2024, at 18:00:00 + Schema_links: [kind, creationTimestamp, Deployment, 2024-01-01T18:00:00Z] + SQL: select * from resources where kind='Deployment' and creationTimestamp < '2024-01-01T18:00:00Z'; + + Q: find the namespace which does not contain banan + Schema_links: [namespace, banan] + SQL: select * from resources where namespace notlike 'banan_'; + + Please convert the text to sql. + ` + + sql_fix_prompt = ` + You are an AI specialized in writing SQL queries. + Please convert the text %s to sql. + The SQL should begin with "select * from". + + The database now only supports one table resources. + + Table resources, columns = [cluster, apiVersion, kind, + namespace, name, creationTimestamp, deletionTimestamp, ownerReferences, + resourceVersion, labels.[key], annotations.[key], content] + + After we executed SQL %s, we observed the following error %s. + Please fix the SQL. + ` ) var ServicePromptMap = map[string]string{ - "default": default_prompt, - "search": search_prompt, + "default": default_prompt, + "Text2sql": text2sql_prompt, + "SqlFix": sql_fix_prompt, } diff --git a/pkg/infra/ai/types.go b/pkg/infra/ai/types.go index 0d2025eb..4437506e 100644 --- a/pkg/infra/ai/types.go +++ b/pkg/infra/ai/types.go @@ -25,6 +25,11 @@ const ( OpenAIProvider = "openai" ) +const ( + Text2sqlType = "Text2sql" + SqlFixType = "SqlFix" +) + var clients = map[string]AIProvider{ AzureProvider: &AzureAIClient{}, HuggingFaceProvider: &HuggingfaceClient{}, @@ -37,7 +42,7 @@ type AIProvider interface { Configure(config AIConfig) error // Generate generates a response from the AI service based on // the provided prompt and service type. - Generate(ctx context.Context, prompt string, serviceType string) (string, error) + Generate(ctx context.Context, prompt string) (string, error) } // AIConfig represents the configuration settings for an AI client. diff --git a/pkg/infra/search/storage/elasticsearch/search.go b/pkg/infra/search/storage/elasticsearch/search.go index 2118a68d..86245eab 100644 --- a/pkg/infra/search/storage/elasticsearch/search.go +++ b/pkg/infra/search/storage/elasticsearch/search.go @@ -47,7 +47,7 @@ func (s *Storage) Search(ctx context.Context, queryStr string, patternType strin if err != nil { return nil, errors.Wrap(err, "search by DSL failed") } - case storage.SQLPatternType: + case storage.SQLPatternType, storage.NLPatternType: sr, err = s.searchBySQL(ctx, queryStr, pagination) if err != nil { return nil, errors.Wrap(err, "search by SQL failed") diff --git a/pkg/infra/search/storage/types.go b/pkg/infra/search/storage/types.go index 7cb6c73b..c90a44c1 100644 --- a/pkg/infra/search/storage/types.go +++ b/pkg/infra/search/storage/types.go @@ -31,6 +31,7 @@ import ( const ( Equals = "=" + NLPatternType = "nl" DSLPatternType = "dsl" SQLPatternType = "sql" )