Skip to content

Commit

Permalink
feat: support natural language as a search method (#556)
Browse files Browse the repository at this point in the history
## What type of PR is this?

/kind feature

## What this PR does / why we need it:

Support natural language as a search method for Kubernetes resources.

## Which issue(s) this PR fixes:

Fixes #452

---------

Co-authored-by: ruquanzhao <[email protected]>
  • Loading branch information
jueli12 and ruquanzhao authored Sep 3, 2024
1 parent bc2b358 commit 0ddb27a
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 20 deletions.
36 changes: 35 additions & 1 deletion pkg/core/handler/search/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
// @Tags search
// @Produce json
// @Param query query string true "The query to use for search. Required"
// @Param pattern query string true "The search pattern. Can be either sql or dsl. Required"
// @Param pattern query string true "The search pattern. Can be either sql, dsl or nl. Required"
// @Param pageSize query string false "The size of the page. Default to 10"
// @Param page query string false "The current page to fetch. Default to 1"
// @Success 200 {array} runtime.Object "Array of runtime.Object"
Expand Down Expand Up @@ -66,9 +66,42 @@ func SearchForResource(searchMgr *search.SearchManager, aiMgr *ai.AIManager, sea
searchPage = 1
}

query := searchQuery

if searchPattern == storage.NLPatternType {
//logger.Info(searchQuery)
res, err := aiMgr.ConvertTextToSQL(searchQuery)
if err != nil {
handler.FailureRender(ctx, w, r, err)
return
}
searchQuery = res
}

//logger.Info(searchQuery)
logger.Info("Searching for resources...", "page", searchPage, "pageSize", searchPageSize)

res, err := searchStorage.Search(ctx, searchQuery, searchPattern, &storage.Pagination{Page: searchPage, PageSize: searchPageSize})
if err != nil {
if searchPattern == storage.NLPatternType {
//logger.Info(err.Error())
fixedQuery, fixErr := aiMgr.FixSQL(query, searchQuery, err.Error())
if fixErr != nil {
handler.FailureRender(ctx, w, r, err)
return
}
searchQuery = fixedQuery
res, err = searchStorage.Search(ctx, searchQuery, searchPattern, &storage.Pagination{Page: searchPage, PageSize: searchPageSize})
if err != nil {
handler.FailureRender(ctx, w, r, err)
return
}
} else {
handler.FailureRender(ctx, w, r, err)
return
}
}

if err != nil {
handler.FailureRender(ctx, w, r, err)
return
Expand All @@ -83,6 +116,7 @@ func SearchForResource(searchMgr *search.SearchManager, aiMgr *ai.AIManager, sea
Object: unObj,
})
}
rt.SQLQuery = searchQuery
rt.Total = res.Total
rt.CurrentPage = searchPage
rt.PageSize = searchPageSize
Expand Down
43 changes: 43 additions & 0 deletions pkg/core/manager/ai/search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright The Karpor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"context"
"fmt"
"github.com/KusionStack/karpor/pkg/infra/ai"
)

// ConvertTextToSQL turns a natural-language request into a SQL statement.
//
// The Text2sql prompt template is filled with query and sent to the
// configured AI client using a background context. The reply is passed
// through ExtractSelectSQL, so the returned string is the "select * from ..."
// fragment found in the reply, or an empty string when the reply contains
// none. A client failure is returned unchanged together with an empty string.
func (a *AIManager) ConvertTextToSQL(query string) (string, error) {
	prompt := fmt.Sprintf(ai.ServicePromptMap[ai.Text2sqlType], query)

	reply, err := a.client.Generate(context.Background(), prompt)
	if err != nil {
		return "", err
	}

	return ExtractSelectSQL(reply), nil
}

// FixSQL asks the AI client to repair a SQL statement that failed to execute.
//
// Parameters:
//   - sql: the statement that was executed and failed.
//   - query: the text the statement was generated from.
//   - sqlErr: the error message observed when sql was executed.
//
// The SqlFix prompt template is filled in the order query, sql, sqlErr and
// sent to the AI client using a background context. The reply is passed
// through ExtractSelectSQL, so the returned string is the "select * from ..."
// fragment found in the reply, or an empty string when the reply contains
// none. A client failure is returned unchanged together with an empty string.
//
// NOTE(review): the search handler calls FixSQL(query, searchQuery, ...),
// i.e. with the natural-language text first and the generated SQL second,
// which looks crossed relative to this signature — confirm and swap the
// arguments at the call site if so.
func (a *AIManager) FixSQL(sql string, query string, sqlErr string) (string, error) {
	// The third parameter is deliberately not called "error": that name would
	// shadow the predeclared error type inside the function body.
	servicePrompt := ai.ServicePromptMap[ai.SqlFixType]
	prompt := fmt.Sprintf(servicePrompt, query, sql, sqlErr)

	res, err := a.client.Generate(context.Background(), prompt)
	if err != nil {
		return "", err
	}
	return ExtractSelectSQL(res), nil
}
24 changes: 24 additions & 0 deletions pkg/core/manager/ai/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright The Karpor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import "regexp"

// selectSQLPattern matches a case-insensitive "SELECT * FROM ..." statement.
// Any run of whitespace (including newlines) may separate the keywords. The
// match stops before the first ';' and always ends on a non-whitespace
// character, so trailing blanks and newlines are not part of the result.
// It is compiled once at package initialization instead of on every call.
var selectSQLPattern = regexp.MustCompile(`(?i)SELECT\s+\*\s+FROM\s+[^;]*[^;\s]`)

// ExtractSelectSQL returns the first "SELECT * FROM ..." statement found in
// sql, without the terminating ';' and without trailing whitespace. The
// original letter case of the input is preserved. It returns an empty string
// when sql contains no such statement.
func ExtractSelectSQL(sql string) string {
	return selectSQLPattern.FindString(sql)
}
44 changes: 44 additions & 0 deletions pkg/core/manager/ai/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright The Karpor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"github.com/stretchr/testify/require"
"testing"
)

// TestExtractSelectSQL tests the correctness of the ExtractSelectSQL function.
func TestExtractSelectSQL(t *testing.T) {
testCases := []struct {
name string
sql string
expected string
}{
{
name: "NormalCase",
sql: "Q: 所有kind=namespace " +
"Schema_links: [kind, namespace] " +
"SQL: select * from resources where kind='namespace';",
expected: "select * from resources where kind='namespace'",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := ExtractSelectSQL(tc.sql)
require.Equal(t, tc.expected, actual)
})
}
}
1 change: 1 addition & 0 deletions pkg/core/manager/search/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type UniResource struct {
type UniResourceList struct {
metav1.TypeMeta
Items []UniResource `json:"items"`
SQLQuery string `json:"sqlQuery"`
Total int `json:"total"`
CurrentPage int `json:"currentPage"`
PageSize int `json:"pageSize"`
Expand Down
7 changes: 1 addition & 6 deletions pkg/infra/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,11 @@ func (c *AzureAIClient) Configure(cfg AIConfig) error {
return nil
}

func (c *AzureAIClient) Generate(ctx context.Context, prompt string, serviceType string) (string, error) {
servicePrompt := ServicePromptMap[serviceType]
func (c *AzureAIClient) Generate(ctx context.Context, prompt string) (string, error) {

resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: servicePrompt,
},
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
Expand Down
2 changes: 1 addition & 1 deletion pkg/infra/ai/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (c *HuggingfaceClient) Configure(cfg AIConfig) error {
return nil
}

func (c *HuggingfaceClient) Generate(ctx context.Context, prompt string, serviceType string) (string, error) {
func (c *HuggingfaceClient) Generate(ctx context.Context, prompt string) (string, error) {
resp, err := c.client.TextGeneration(ctx, &huggingface.TextGenerationRequest{
Inputs: prompt,
Parameters: huggingface.TextGenerationParameters{
Expand Down
8 changes: 1 addition & 7 deletions pkg/infra/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,10 @@ func (c *OpenAIClient) Configure(cfg AIConfig) error {
return nil
}

func (c *OpenAIClient) Generate(ctx context.Context, prompt string, serviceType string) (string, error) {
servicePrompt := ServicePromptMap[serviceType]

func (c *OpenAIClient) Generate(ctx context.Context, prompt string) (string, error) {
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: servicePrompt,
},
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
Expand Down
71 changes: 68 additions & 3 deletions pkg/infra/ai/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,75 @@ package ai

// Prompt texts for the AI services. The multi-line prompts are raw string
// literals, so every character between the backticks (including line breaks)
// is sent to the model as written.
const (
// default_prompt is a generic assistant instruction.
default_prompt = "You are a helpful assistant."
// search_prompt instructs the model to act as a SQL writer and to return the
// statement as a formatted string.
search_prompt = "You are an AI specialized in writing SQL queries. Please provide the SQL statement as a formatted string."

// text2sql_prompt is a format template with a single %s verb that receives
// the text to convert. It asks for a statement that begins with
// "select * from" and ends with ";", describes the single resources table and
// its columns, and gives worked examples of schema linking and SQL generation.
text2sql_prompt = `
You are an AI specialized in writing SQL queries.
Please convert the text %s to sql.
The output tokens only need to give the SQL first, the other thought process please do not give.
The SQL should begin with "select * from" and end with ";".
1. The database now only supports one table resources.
Table resources, columns = [cluster, apiVersion, kind,
namespace, name, creationTimestamp, deletionTimestamp, ownerReferences,
resourceVersion, labels.[key], annotations.[key], content]
2. find the schema_links for generating SQL queries for each question based on the database schema.
Follow are some examples.
Q: find the kind which is not equal to pod
A: Let’s think step by step. In the question "find the kind column which is not equal to pod", we are asked:
"find the kind" so we need column = [kind]
Based on the columns, the set of possible cell values are = [pod].
So the Schema_links are:
Schema_links: [kind, pod]
Q: find the kind Deployment which created before January 1, 2024, at 18:00:00
A: Let’s think step by step. In the question "find the kind Deployment which created before January 1, 2024, at 18:00:00", we are asked:
"find the kind Deployment" so we need column = [kind]
"created before" so we need column = [creationTimestamp]
Based on the columns, the set of possible cell values are = [Deployment, 2024-01-01T18:00:00Z].
So the Schema_links are:
Schema_links: [kind, creationTimestamp, Deployment, 2024-01-01T18:00:00Z]
3. Use the the schema links to generate the SQL queries for each of the questions.
Follow are some examples.
Q: find the kind which is not equal to pod
Schema_links: [kind, pod]
SQL: select * from resources where kind!='Pod';
Q: find the kind Deployment which created before January 1, 2024, at 18:00:00
Schema_links: [kind, creationTimestamp, Deployment, 2024-01-01T18:00:00Z]
SQL: select * from resources where kind='Deployment' and creationTimestamp < '2024-01-01T18:00:00Z';
Q: find the namespace which does not contain banan
Schema_links: [namespace, banan]
SQL: select * from resources where namespace notlike 'banan_';
Please convert the text to sql.
`

// sql_fix_prompt is a format template with three %s verbs, filled in this
// order: the text to convert, the SQL statement that was executed, and the
// error observed when executing it. It asks the model to fix the SQL.
sql_fix_prompt = `
You are an AI specialized in writing SQL queries.
Please convert the text %s to sql.
The SQL should begin with "select * from".
The database now only supports one table resources.
Table resources, columns = [cluster, apiVersion, kind,
namespace, name, creationTimestamp, deletionTimestamp, ownerReferences,
resourceVersion, labels.[key], annotations.[key], content]
After we executed SQL %s, we observed the following error %s.
Please fix the SQL.
`
)

// ServicePromptMap maps a service-type key to the prompt text used for that
// service. The "Text2sql" and "SqlFix" entries are fmt-style templates; the
// other entries are plain instructions.
//
// Each key is listed exactly once: Go rejects a map literal that repeats a
// constant key at compile time.
var ServicePromptMap = map[string]string{
	"default":  default_prompt,
	"search":   search_prompt,
	"Text2sql": text2sql_prompt,
	"SqlFix":   sql_fix_prompt,
}
7 changes: 6 additions & 1 deletion pkg/infra/ai/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ const (
OpenAIProvider = "openai"
)

// Text2sqlType is the service-type key for converting natural-language text
// into SQL.
const Text2sqlType = "Text2sql"

// SqlFixType is the service-type key for repairing a SQL statement that
// failed to execute.
const SqlFixType = "SqlFix"

var clients = map[string]AIProvider{
AzureProvider: &AzureAIClient{},
HuggingFaceProvider: &HuggingfaceClient{},
Expand All @@ -37,7 +42,7 @@ type AIProvider interface {
Configure(config AIConfig) error
// Generate generates a response from the AI service based on
// the provided prompt.
Generate(ctx context.Context, prompt string, serviceType string) (string, error)
Generate(ctx context.Context, prompt string) (string, error)
}

// AIConfig represents the configuration settings for an AI client.
Expand Down
2 changes: 1 addition & 1 deletion pkg/infra/search/storage/elasticsearch/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *Storage) Search(ctx context.Context, queryStr string, patternType strin
if err != nil {
return nil, errors.Wrap(err, "search by DSL failed")
}
case storage.SQLPatternType:
case storage.SQLPatternType, storage.NLPatternType:
sr, err = s.searchBySQL(ctx, queryStr, pagination)
if err != nil {
return nil, errors.Wrap(err, "search by SQL failed")
Expand Down
1 change: 1 addition & 0 deletions pkg/infra/search/storage/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (

// Equals is the literal "=" string.
const Equals = "="

// Search pattern types accepted by the search storage.
const (
	// SQLPatternType selects searching with a SQL statement.
	SQLPatternType = "sql"
	// DSLPatternType selects searching with a DSL expression.
	DSLPatternType = "dsl"
	// NLPatternType selects searching with natural-language text.
	NLPatternType = "nl"
)
Expand Down

0 comments on commit 0ddb27a

Please sign in to comment.