Skip to content

Commit

Permalink
feat: add gemini pro, gemini pro vision models
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Dec 28, 2023
1 parent 3d1e8c8 commit 261e500
Show file tree
Hide file tree
Showing 12 changed files with 304 additions and 27 deletions.
82 changes: 67 additions & 15 deletions adapter/palm2/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,23 @@ import (
"fmt"
)

var geminiMaxImages = 16

type ChatProps struct {
Model string
Message []globals.Message
Model string
Message []globals.Message
Temperature *float64
TopP *float64
TopK *int
MaxOutputTokens *int
}

func (c *ChatInstance) GetChatEndpoint(model string) string {
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
if model == globals.ChatBison001 {
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
}

return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
}

func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage {
Expand Down Expand Up @@ -41,31 +51,73 @@ func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage {
return result
}

func (c *ChatInstance) GetChatBody(props *ChatProps) *ChatBody {
return &ChatBody{
Prompt: Prompt{
func (c *ChatInstance) GetPalm2ChatBody(props *ChatProps) *PalmChatBody {
return &PalmChatBody{
Prompt: PalmPrompt{
Messages: c.ConvertMessage(props.Message),
},
}
}

func (c *ChatInstance) GetGeminiChatBody(props *ChatProps) *GeminiChatBody {
return &GeminiChatBody{
Contents: c.GetGeminiContents(props.Model, props.Message),
GenerationConfig: GeminiConfig{
Temperature: props.Temperature,
MaxOutputTokens: props.MaxOutputTokens,
TopP: props.TopP,
TopK: props.TopK,
},
}
}

func (c *ChatInstance) GetPalm2ChatResponse(data interface{}) (string, error) {
if form := utils.MapToStruct[PalmChatResponse](data); form != nil {
if len(form.Candidates) == 0 {
return "", fmt.Errorf("palm2 error: the content violates content policy")
}
return form.Candidates[0].Content, nil
}
return "", fmt.Errorf("palm2 error: cannot parse response")
}

func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
if form := utils.MapToStruct[GeminiChatResponse](data); form != nil {
if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 {
return form.Candidates[0].Content.Parts[0].Text, nil
}
}

if form := utils.MapToStruct[GeminiChatErrorResponse](data); form != nil {
return "", fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
}

return "", fmt.Errorf("gemini: cannot parse response")
}

func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
uri := c.GetChatEndpoint(props.Model)

if props.Model == globals.ChatBison001 {
data, err := utils.Post(uri, map[string]string{
"Content-Type": "application/json",
}, c.GetPalm2ChatBody(props))

if err != nil {
return "", fmt.Errorf("palm2 error: %s", err.Error())
}
return c.GetPalm2ChatResponse(data)
}

data, err := utils.Post(uri, map[string]string{
"Content-Type": "application/json",
}, c.GetChatBody(props))
}, c.GetGeminiChatBody(props))

if err != nil {
return "", fmt.Errorf("palm2 error: %s", err.Error())
return "", fmt.Errorf("gemini error: %s", err.Error())
}

if form := utils.MapToStruct[ChatResponse](data); form != nil {
if len(form.Candidates) == 0 {
return "I don't know how to respond to that. Please try another question.", nil
}
return form.Candidates[0].Content, nil
}
return "", fmt.Errorf("palm2 error: cannot parse response")
return c.GetGeminiChatResponse(data)
}

// CreateStreamChatRequest is the mock stream request for palm2
Expand Down
106 changes: 106 additions & 0 deletions adapter/palm2/formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package palm2

import (
"chat/globals"
"chat/utils"
"strings"
)

func getGeminiRole(role string) string {
switch role {
case globals.User:
return GeminiUserType
case globals.Assistant, globals.Tool, globals.System:
return GeminiModelType
default:
return GeminiUserType
}
}

func getMimeType(content string) string {
segment := strings.Split(content, ".")
if len(segment) == 0 || len(segment) == 1 {
return "image/png"
}

suffix := strings.TrimSpace(strings.ToLower(segment[len(segment)-1]))

switch suffix {
case "png":
return "image/png"
case "jpg", "jpeg":
return "image/jpeg"
case "gif":
return "image/gif"
case "webp":
return "image/webp"
case "heif":
return "image/heif"
case "heic":
return "image/heic"
default:
return "image/png"
}
}

func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
parts = append(parts, GeminiChatPart{
Text: &content,
})

if model == globals.GeminiPro {
return parts
}

urls := utils.ExtractImageUrls(content)
if len(urls) > geminiMaxImages {
urls = urls[:geminiMaxImages]
}

for _, url := range urls {
data, err := utils.ConvertToBase64(url)
if err != nil {
continue
}

parts = append(parts, GeminiChatPart{
InlineData: &GeminiInlineData{
MimeType: getMimeType(url),
Data: data,
},
})
}

return parts
}

func (c *ChatInstance) GetGeminiContents(model string, message []globals.Message) []GeminiContent {
// gemini role should be user-model

result := make([]GeminiContent, 0)
for _, item := range message {
role := getGeminiRole(item.Role)
if len(item.Content) == 0 {
// gemini model: message must include non empty content
continue
}

if len(result) == 0 && getGeminiRole(item.Role) == GeminiModelType {
// gemini model: first message must be user
continue
}

if len(result) > 0 && role == result[len(result)-1].Role {
// gemini model: messages must alternate between authors
result[len(result)-1].Parts = getGeminiContent(result[len(result)-1].Parts, item.Content, model)
continue
}

result = append(result, GeminiContent{
Role: getGeminiRole(item.Role),
Parts: getGeminiContent(make([]GeminiChatPart, 0), item.Content, model),
})
}

return result
}
64 changes: 58 additions & 6 deletions adapter/palm2/types.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,72 @@
package palm2

const (
GeminiUserType = "user"
GeminiModelType = "model"
)

type PalmMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}

// ChatBody is the native http request body for palm2
type ChatBody struct {
Prompt Prompt `json:"prompt"`
// PalmChatBody is the native http request body for palm2
type PalmChatBody struct {
Prompt PalmPrompt `json:"prompt"`
}

type Prompt struct {
type PalmPrompt struct {
Messages []PalmMessage `json:"messages"`
}

// ChatResponse is the native http response body for palm2
type ChatResponse struct {
// PalmChatResponse is the native http response body for palm2
type PalmChatResponse struct {
Candidates []PalmMessage `json:"candidates"`
}

// GeminiChatBody is the native http request body for gemini
type GeminiChatBody struct {
Contents []GeminiContent `json:"contents"`
GenerationConfig GeminiConfig `json:"generationConfig"`
}

type GeminiConfig struct {
Temperature *float64 `json:"temperature,omitempty"`
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
}

type GeminiContent struct {
Role string `json:"role"`
Parts []GeminiChatPart `json:"parts"`
}

type GeminiChatPart struct {
Text *string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inline_data,omitempty"`
}

type GeminiInlineData struct {
MimeType string `json:"mime_type"`
Data string `json:"data"`
}

type GeminiChatResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
Role string `json:"role"`
} `json:"content"`
} `json:"candidates"`
}

type GeminiChatErrorResponse struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
Binary file added app/public/icons/gemini.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions app/src/admin/channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export const ChannelTypes: Record<string, string> = {
baichuan: "百川 AI",
skylark: "火山方舟",
bing: "New Bing",
palm: "Google PaLM2",
palm: "Google Gemini",
midjourney: "Midjourney",
oneapi: "Nio API",
};
Expand Down Expand Up @@ -141,7 +141,7 @@ export const ChannelInfos: Record<string, ChannelInfo> = {
id: 11,
endpoint: "https://generativelanguage.googleapis.com",
format: "<api-key>",
models: ["chat-bison-001"],
models: ["chat-bison-001", "gemini-pro", "gemini-pro-vision"],
},
midjourney: {
id: 12,
Expand Down
2 changes: 2 additions & 0 deletions app/src/admin/colors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ export const modelColorMapper: Record<string, string> = {
"spark-desk-v3": "#06b3e8",

"chat-bison-001": "#f82a53",
"gemini-pro": "#f82a53",
"gemini-pro-vision": "#f82a53",

"bing-creative": "#2673e7",
"bing-balanced": "#2673e7",
Expand Down
21 changes: 21 additions & 0 deletions app/src/conf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,22 @@ export const supportModels: Model[] = [
tag: ["free", "english-model"],
},

// gemini
{
id: "gemini-pro",
name: "Gemini Pro",
free: true,
auth: true,
tag: ["free", "official"],
},
{
id: "gemini-pro-vision",
name: "Gemini Pro Vision",
free: true,
auth: true,
tag: ["free", "official", "multi-modal"],
},

// drawing models
{
id: "midjourney",
Expand Down Expand Up @@ -346,6 +362,9 @@ export const defaultModels = [
"zhipu-chatglm-turbo",
"baichuan-53b",

"gemini-pro",
"gemini-pro-vision",

"dall-e-2",
"midjourney-fast",
"stable-diffusion",
Expand Down Expand Up @@ -412,6 +431,8 @@ export const modelAvatars: Record<string, string> = {
"midjourney-turbo": "midjourney.jpg",
"bing-creative": "newbing.jpg",
"chat-bison-001": "palm2.webp",
"gemini-pro": "gemini.jpeg",
"gemini-pro-vision": "gemini.jpeg",
"zhipu-chatglm-turbo": "chatglm.png",
"qwen-plus-net": "tongyi.png",
"qwen-plus": "tongyi.png",
Expand Down
4 changes: 2 additions & 2 deletions app/src/routes/admin/System.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ function Mail({ data, dispatch, onChange }: CompProps<MailState>) {
<ParagraphFooter>
<div className={`grow`} />
<Dialog open={mailDialog} onOpenChange={setMailDialog}>
<DialogTrigger>
<Button variant={`outline`} size={`sm`} loading={true}>
<DialogTrigger asChild>
<Button variant={`outline`} size={`sm`}>
{t("admin.system.test")}
</Button>
</DialogTrigger>
Expand Down
2 changes: 2 additions & 0 deletions globals/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ const (
SparkDeskV2 = "spark-desk-v2"
SparkDeskV3 = "spark-desk-v3"
ChatBison001 = "chat-bison-001"
GeminiPro = "gemini-pro"
GeminiProVision = "gemini-pro-vision"
BingCreative = "bing-creative"
BingBalanced = "bing-balanced"
BingPrecise = "bing-precise"
Expand Down
4 changes: 2 additions & 2 deletions utils/char.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ func ExtractUrls(data string) []string {
func ExtractImageUrls(data string) []string {
// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload

re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp))`)
return re.FindAllString(data, -1)
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic))`)
return re.FindAllString(strings.ToLower(data), -1)
}

func ContainUnicode(data string) bool {
Expand Down
Loading

0 comments on commit 261e500

Please sign in to comment.