Skip to content

Commit

Permalink
feat: add ai manager
Browse files Browse the repository at this point in the history
  • Loading branch information
jueli12 committed Aug 25, 2024
1 parent eec1100 commit 528bc27
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
43 changes: 43 additions & 0 deletions pkg/core/manager/ai/search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright The Karpor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"context"
"fmt"
"github.com/KusionStack/karpor/pkg/infra/ai"
)

// ConvertTextToSQL converts natural language text to an SQL query
func (a *AIManager) ConvertTextToSQL(query string) (string, error) {
servicePrompt := ai.ServicePromptMap[ai.Text2sqlType]
prompt := fmt.Sprintf(servicePrompt, query)
res, err := a.client.Generate(context.Background(), prompt)
if err != nil {
return "", err
}
return ExtractSelectSQL(res), nil
}

// FixSQL fix the error SQL
func (a *AIManager) FixSQL(sql string, query string, error string) (string, error) {
servicePrompt := ai.ServicePromptMap[ai.SqlFixType]
prompt := fmt.Sprintf(servicePrompt, query, sql, error)
res, err := a.client.Generate(context.Background(), prompt)
if err != nil {
return "", err
}
return ExtractSelectSQL(res), nil
}
24 changes: 24 additions & 0 deletions pkg/core/manager/ai/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright The Karpor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import "regexp"

// ExtractSelectSQL extracts SQL statements that start with "SELECT * FROM"
func ExtractSelectSQL(sql string) string {
res := regexp.MustCompile(`(?i)SELECT \* FROM [^;]+`)
match := res.FindString(sql)
return match
}
44 changes: 44 additions & 0 deletions pkg/core/manager/ai/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright The Karpor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"github.com/stretchr/testify/require"
"testing"
)

// TestExtractSelectSQL tests the correctness of the ExtractSelectSQL function.
func TestExtractSelectSQL(t *testing.T) {
testCases := []struct {
name string
sql string
expected string
}{
{
name: "NormalCase",
sql: "Q: 所有kind=namespace " +
"Schema_links: [kind, namespace] " +
"SQL: select * from resources where kind='namespace'",
expected: "select * from resources where kind='namespace'",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := ExtractSelectSQL(tc.sql)
require.Equal(t, tc.expected, actual)
})
}
}

0 comments on commit 528bc27

Please sign in to comment.