From 88b79db7212051a90a91027445db77de90dd5f3c Mon Sep 17 00:00:00 2001 From: MartialBE Date: Fri, 12 Jul 2024 21:09:12 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=96=20chore:=20support=20vertexai?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/config/constants.go | 2 + go.mod | 25 +++- go.sum | 54 ++++++++- providers/bedrock/category/base.go | 13 +- providers/claude/base.go | 4 +- providers/gemini/base.go | 11 +- providers/gemini/chat.go | 24 ++-- providers/providers.go | 2 + providers/vertexai/base.go | 165 ++++++++++++++++++++++++++ providers/vertexai/category/base.go | 44 +++++++ providers/vertexai/category/claude.go | 87 ++++++++++++++ providers/vertexai/category/gemini.go | 63 ++++++++++ providers/vertexai/chat.go | 80 +++++++++++++ providers/vertexai/type.go | 29 +++++ web/src/constants/ChannelConstants.js | 7 ++ web/src/views/Channel/type/Config.js | 11 ++ 16 files changed, 596 insertions(+), 25 deletions(-) create mode 100644 providers/vertexai/base.go create mode 100644 providers/vertexai/category/base.go create mode 100644 providers/vertexai/category/claude.go create mode 100644 providers/vertexai/category/gemini.go create mode 100644 providers/vertexai/chat.go create mode 100644 providers/vertexai/type.go diff --git a/common/config/constants.go b/common/config/constants.go index bccfa67eb..5362b1a66 100644 --- a/common/config/constants.go +++ b/common/config/constants.go @@ -184,6 +184,7 @@ const ( ChannelTypeOllama = 39 ChannelTypeHunyuan = 40 ChannelTypeSuno = 41 + ChannelTypeVertexAI = 42 ) var ChannelBaseURLs = []string{ @@ -229,6 +230,7 @@ var ChannelBaseURLs = []string{ "", //39 "https://hunyuan.tencentcloudapi.com", //40 "", //41 + "", //42 } const ( diff --git a/go.mod b/go.mod index 804f1618c..cd5968823 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ go 1.22 toolchain go1.22.3 require ( + cloud.google.com/go/iam v1.1.11 github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 github.com/aws/smithy-go v1.20.1 @@ -38,6 +39,7 @@ require ( github.com/wneessen/go-mail v0.4.1 golang.org/x/crypto v0.25.0 golang.org/x/image v0.18.0 + google.golang.org/api v0.188.0 gorm.io/driver/mysql v1.5.5 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 @@ -45,13 +47,22 @@ require ( ) require ( + cloud.google.com/go/auth v0.7.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect + cloud.google.com/go/compute/metadata v0.4.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/s2a-go v0.1.7 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect + github.com/googleapis/gax-go/v2 v2.12.5 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jonboulle/clockwork v0.4.0 // indirect @@ -74,11 +85,21 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect + go.opentelemetry.io/otel v1.24.0 // indirect + go.opentelemetry.io/otel/metric v1.24.0 // indirect + go.opentelemetry.io/otel/trace v1.24.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f // indirect + golang.org/x/oauth2 v0.21.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/time v0.5.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b // indirect + google.golang.org/grpc v1.64.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index f64c152bb..d09fc194c 100644 --- a/go.sum +++ b/go.sum @@ -13,14 +13,22 @@ cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKV cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= +cloud.google.com/go/auth v0.7.0 h1:kf/x9B3WTbBUHkC+1VS8wwwli9TzhSt0vSTVBmMR8Ts= +cloud.google.com/go/auth v0.7.0/go.mod h1:D+WqdrpcjmiCgWrXmLLxOVq1GACoE36chW6KXoEvuIw= +cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= +cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= +cloud.google.com/go/compute/metadata v0.4.0 h1:vHzJCWaM4g8XIcm8kopr3XmDA4Gy/lblD3EhhSux05c= +cloud.google.com/go/compute/metadata v0.4.0/go.mod h1:SIQh1Kkb4ZJ8zJ874fqVkslA29PRXuleyj6vOzlbK7M= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= +cloud.google.com/go/iam v1.1.11 h1:0mQ8UKSfdHLut6pH9FM3bI55KWR46ketn0PuXleDyxw= +cloud.google.com/go/iam v1.1.11/go.mod h1:biXoiLWYIKntto2joP+62sd9uW5EpkZmKIvfNcTWlnQ= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= @@ -106,6 +114,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -142,6 +152,11 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= @@ -175,6 +190,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -200,8 +217,8 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 h1:4gjrh/PN2MuWCCElk8/I4OCKRKWCCo2zEct3VKCbibU= github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= @@ -214,6 +231,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -231,10 +249,17 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= +github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= +github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= +github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= @@ -443,6 +468,18 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= +go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= +go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= +go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= +go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= +go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= +go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -524,6 +561,7 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -538,6 +576,8 @@ golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4Iltr golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -582,6 +622,7 @@ golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -674,6 +715,8 @@ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0M google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= +google.golang.org/api v0.188.0 h1:51y8fJ/b1AaaBRJr4yWm96fPcuxSo0JcegXE3DaHQHw= +google.golang.org/api v0.188.0/go.mod h1:VR0d+2SIiWOYG3r/jdm7adPW9hI2aRv9ETOSCQ9Beag= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -709,6 +752,10 @@ google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7Fc google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= +google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b h1:04+jVzTs2XBnOZcPsLnmrTGqltqJbZQ1Ey26hjYdQQ0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -721,6 +768,9 @@ google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKa google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/providers/bedrock/category/base.go b/providers/bedrock/category/base.go index d362470a3..6b2d62d06 100644 --- a/providers/bedrock/category/base.go +++ b/providers/bedrock/category/base.go @@ -34,12 +34,13 @@ func GetCategory(modelName string) (*Category, error) { func GetModelName(modelName string) string { bedrockMap := map[string]string{ - "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", - "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", - "claude-2.1": "anthropic.claude-v2:1", - "claude-2.0": "anthropic.claude-v2", - "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-2.1": "anthropic.claude-v2:1", + "claude-2.0": "anthropic.claude-v2", + "claude-instant-1.2": "anthropic.claude-instant-v1", } if value, exists := bedrockMap[modelName]; exists { diff --git a/providers/claude/base.go b/providers/claude/base.go index 260d70be7..1f505d1e0 100644 --- a/providers/claude/base.go +++ b/providers/claude/base.go @@ -19,7 +19,7 @@ func (f ClaudeProviderFactory) Create(channel *model.Channel) base.ProviderInter BaseProvider: base.BaseProvider{ Config: getConfig(), Channel: channel, - Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle), + Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle), }, } } @@ -36,7 +36,7 @@ func getConfig() base.ProviderConfig { } // 请求错误处理 -func requestErrorHandle(resp *http.Response) *types.OpenAIError { +func RequestErrorHandle(resp *http.Response) *types.OpenAIError { claudeError := &ClaudeError{} err := json.NewDecoder(resp.Body).Decode(claudeError) if err != nil { diff --git a/providers/gemini/base.go b/providers/gemini/base.go index 364f9f804..11ef0098c 100644 --- a/providers/gemini/base.go +++ b/providers/gemini/base.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "one-api/common/logger" "one-api/common/requester" "one-api/model" "one-api/providers/base" @@ -19,7 +20,7 @@ func (f GeminiProviderFactory) Create(channel *model.Channel) base.ProviderInter BaseProvider: base.BaseProvider{ Config: getConfig(), Channel: channel, - Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle), + Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle), }, } } @@ -37,7 +38,7 @@ func getConfig() base.ProviderConfig { } // 请求错误处理 -func requestErrorHandle(resp *http.Response) *types.OpenAIError { +func RequestErrorHandle(resp *http.Response) *types.OpenAIError { geminiError := &GeminiErrorResponse{} err := json.NewDecoder(resp.Body).Decode(geminiError) if err != nil { @@ -52,6 +53,12 @@ func errorHandle(geminiError *GeminiErrorResponse) *types.OpenAIError { if geminiError.Error.Message == "" { return nil } + + if strings.Contains(geminiError.Error.Message, "Publisher Model") { + logger.SysError(fmt.Sprintf("Gemini Error: %s", geminiError.Error.Message)) + geminiError.Error.Message = "上游错误,请联系管理员." + } + return &types.OpenAIError{ Message: geminiError.Error.Message, Type: "gemini_error", diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index 566586e6b..e28feb81a 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/common/requester" "one-api/common/utils" + "one-api/providers/base" "one-api/types" "strings" ) @@ -15,7 +16,7 @@ const ( GeminiVisionMaxImageNum = 16 ) -type geminiStreamHandler struct { +type GeminiStreamHandler struct { Usage *types.Usage LastCandidates int LastType string @@ -36,7 +37,7 @@ func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionReque return nil, errWithCode } - return p.convertToChatOpenai(geminiChatResponse, request) + return ConvertToChatOpenai(p, geminiChatResponse, request) } func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { @@ -52,14 +53,14 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio return nil, errWithCode } - chatHandler := &geminiStreamHandler{ + chatHandler := &GeminiStreamHandler{ Usage: p.Usage, LastCandidates: 0, LastType: "", Request: request, } - return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerStream) } func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -76,7 +77,7 @@ func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (* headers["Accept"] = "text/event-stream" } - geminiRequest, errWithCode := convertFromChatOpenai(request) + geminiRequest, errWithCode := ConvertFromChatOpenai(request) if errWithCode != nil { return nil, errWithCode } @@ -92,7 +93,7 @@ func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (* return req, nil } -func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatRequest, *types.OpenAIErrorWithStatusCode) { +func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatRequest, *types.OpenAIErrorWithStatusCode) { request.ClearEmptyMessages() geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(request.Messages)), @@ -141,7 +142,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq return &geminiRequest, nil } -func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { +func ConvertToChatOpenai(provider base.ProviderInterface, response *GeminiChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { aiError := errorHandle(&response.GeminiErrorResponse) if aiError != nil { errWithCode = &types.OpenAIErrorWithStatusCode{ @@ -162,14 +163,15 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque openaiResponse.Choices = append(openaiResponse.Choices, candidate.ToOpenAIChoice(request)) } - *p.Usage = convertOpenAIUsage(request.Model, response.UsageMetadata) - openaiResponse.Usage = p.Usage + usage := provider.GetUsage() + *usage = convertOpenAIUsage(request.Model, response.UsageMetadata) + openaiResponse.Usage = usage return } // 转换为OpenAI聊天流式请求体 -func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { +func (h *GeminiStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil @@ -196,7 +198,7 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin } -func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) { +func (h *GeminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) { streamResponse := types.ChatCompletionStreamResponse{ ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()), Object: "chat.completion.chunk", diff --git a/providers/providers.go b/providers/providers.go index 86c0a5f5b..15f619da8 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -29,6 +29,7 @@ import ( "one-api/providers/stabilityAI" "one-api/providers/suno" "one-api/providers/tencent" + "one-api/providers/vertexai" "one-api/providers/xunfei" "one-api/providers/zhipu" @@ -72,6 +73,7 @@ func init() { providerFactories[config.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{} providerFactories[config.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{} providerFactories[config.ChannelTypeSuno] = suno.SunoProviderFactory{} + providerFactories[config.ChannelTypeVertexAI] = vertexai.VertexAIProviderFactory{} } diff --git a/providers/vertexai/base.go b/providers/vertexai/base.go new file mode 100644 index 000000000..19865c5a5 --- /dev/null +++ b/providers/vertexai/base.go @@ -0,0 +1,165 @@ +package vertexai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common/cache" + "one-api/common/logger" + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "one-api/providers/vertexai/category" + "one-api/types" + "strings" + "time" + + credentials "cloud.google.com/go/iam/credentials/apiv1" + "cloud.google.com/go/iam/credentials/apiv1/credentialspb" + "google.golang.org/api/option" +) + +const TokenCacheKey = "api_token:vertexai" +const defaultScope = "https://www.googleapis.com/auth/cloud-platform" + +type VertexAIProviderFactory struct{} + +// 创建 VertexAIProvider +func (f VertexAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + vertexAIProvider := &VertexAIProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, nil), + }, + } + + getKeyConfig(vertexAIProvider) + + return vertexAIProvider +} + +type VertexAIProvider struct { + base.BaseProvider + Region string + ProjectID string + Category *category.Category +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + ChatCompletions: "/", + } +} + +func getKeyConfig(vertexAI *VertexAIProvider) { + keys := strings.Split(vertexAI.Channel.Other, "|") + if len(keys) != 2 { + return + } + + vertexAI.Region = keys[0] + vertexAI.ProjectID = keys[1] +} + +func (p *VertexAIProvider) GetFullRequestURL(modelName string, other string) string { + return fmt.Sprintf(p.GetBaseURL(), p.Region, p.ProjectID, p.Region, modelName, other) +} + +func (p *VertexAIProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + + token, err := p.GetToken() + if err != nil { + logger.SysError("Failed to get token: " + err.Error()) + return headers + } + + headers["Authorization"] = "Bearer " + token + + return headers +} + +func (p *VertexAIProvider) GetToken() (string, error) { + cacheKey := fmt.Sprintf("%s:%s", TokenCacheKey, p.ProjectID) + token, err := cache.GetCache[string](cacheKey) + if err != nil { + logger.SysError("Failed to get token from cache: " + err.Error()) + } + + if token != "" { + return token, nil + } + + creds := &Credentials{} + if err := json.Unmarshal([]byte(p.Channel.Key), creds); err != nil { + return "", fmt.Errorf("failed to unmarshal credentials: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(p.Channel.Key))) + if err != nil { + return "", fmt.Errorf("failed to create IAM credentials client: %w", err) + } + defer client.Close() + + req := &credentialspb.GenerateAccessTokenRequest{ + Name: fmt.Sprintf("projects/-/serviceAccounts/%s", creds.ClientEmail), + Scope: []string{defaultScope}, + } + + resp, err := client.GenerateAccessToken(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to generate access token: %w", err) + } + + duration := time.Until(resp.ExpireTime.AsTime()) + cache.SetCache(cacheKey, resp.AccessToken, duration) + + return resp.AccessToken, nil +} + +func RequestErrorHandle(otherErr requester.HttpErrorHandler) requester.HttpErrorHandler { + + return func(resp *http.Response) *types.OpenAIError { + requestBody, _ := io.ReadAll(resp.Body) + resp.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + if otherErr != nil { + err := otherErr(resp) + if err != nil { + return err + } + } + vertexaiErrors := &VertexaiErrors{} + err := json.Unmarshal(requestBody, vertexaiErrors) + if err != nil { + return nil + } + vertexaiError := vertexaiErrors.Error() + + return errorHandle(vertexaiError) + + } +} + +func errorHandle(vertexaiError *VertexaiError) *types.OpenAIError { + if vertexaiError.Error.Message == "" { + return nil + } + + logger.SysError(fmt.Sprintf("VertexAI error: %s", vertexaiError.Error.Message)) + + return &types.OpenAIError{ + Message: "VertexAI错误", + Type: "gemini_error", + Param: vertexaiError.Error.Status, + Code: vertexaiError.Error.Code, + } +} diff --git a/providers/vertexai/category/base.go b/providers/vertexai/category/base.go new file mode 100644 index 000000000..bbb1738c8 --- /dev/null +++ b/providers/vertexai/category/base.go @@ -0,0 +1,44 @@ +package category + +import ( + "errors" + "net/http" + "one-api/common/requester" + "one-api/providers/base" + "one-api/types" + "strings" +) + +type Category struct { + ChatComplete ChatCompletionConvert + ResponseChatComplete ChatCompletionResponse + ResponseChatCompleteStrem ChatCompletionStreamResponse + ErrorHandler requester.HttpErrorHandler + GetModelName func(string) string + GetOtherUrl func(bool) string +} + +var CategoryMap = map[string]*Category{} + +func GetCategory(modelName string) (*Category, error) { + + category := "" + + if strings.HasPrefix(modelName, "gemini") { + category = "gemini" + } else if strings.HasPrefix(modelName, "claude") { + category = "claude" + } + + if category == "" { + return nil, errors.New("category_not_found") + } + + return CategoryMap[category], nil + +} + +type ChatCompletionConvert func(*types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) +type ChatCompletionResponse func(base.ProviderInterface, *http.Response, *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) + +type ChatCompletionStreamResponse func(base.ProviderInterface, *types.ChatCompletionRequest) requester.HandlerPrefix[string] diff --git a/providers/vertexai/category/claude.go b/providers/vertexai/category/claude.go new file mode 100644 index 000000000..215f0450b --- /dev/null +++ b/providers/vertexai/category/claude.go @@ -0,0 +1,87 @@ +package category + +import ( + "encoding/json" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/base" + "one-api/providers/claude" + "one-api/types" +) + +const anthropicVersion = "vertex-2023-10-16" + +type ClaudeRequest struct { + *claude.ClaudeRequest + AnthropicVersion string `json:"anthropic_version"` +} + +var claudeMap = map[string]string{ + "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", + "claude-3-opus-20240229": "claude-3-opus@20240229", + "claude-3-sonnet-20240229": "claude-3-sonnet@20240229", + "claude-3-haiku-20240307": "claude-3-haiku@20240307", +} + +func init() { + CategoryMap["claude"] = &Category{ + ChatComplete: ConvertClaudeFromChatOpenai, + ResponseChatComplete: ConvertClaudeToChatOpenai, + ResponseChatCompleteStrem: ClaudeChatCompleteStrem, + ErrorHandler: claude.RequestErrorHandle, + GetModelName: GetClaudeModelName, + GetOtherUrl: getClaudeOtherUrl, + } +} + +func ConvertClaudeFromChatOpenai(request *types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) { + rawRequest, err := claude.ConvertFromChatOpenai(request) + if err != nil { + return nil, err + } + + claudeRequest := &ClaudeRequest{} + claudeRequest.ClaudeRequest = rawRequest + claudeRequest.AnthropicVersion = anthropicVersion + + // 删除model字段 + claudeRequest.Model = "" + + return claudeRequest, nil +} + +func ConvertClaudeToChatOpenai(provider base.ProviderInterface, response *http.Response, request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + claudeResponse := &claude.ClaudeResponse{} + err := json.NewDecoder(response.Body).Decode(claudeResponse) + if err != nil { + return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + } + + return claude.ConvertToChatOpenai(provider, claudeResponse, request) +} + +func ClaudeChatCompleteStrem(provider base.ProviderInterface, request *types.ChatCompletionRequest) requester.HandlerPrefix[string] { + chatHandler := &claude.ClaudeStreamHandler{ + Usage: provider.GetUsage(), + Request: request, + Prefix: `data: {"type"`, + } + + return chatHandler.HandlerStream +} + +func GetClaudeModelName(modelName string) string { + if value, exists := claudeMap[modelName]; exists { + modelName = value + } + + return modelName +} + +func getClaudeOtherUrl(stream bool) string { + if stream { + return "streamRawPredict" + } + return "rawPredict" +} diff --git a/providers/vertexai/category/gemini.go b/providers/vertexai/category/gemini.go new file mode 100644 index 000000000..6cbd50eb9 --- /dev/null +++ b/providers/vertexai/category/gemini.go @@ -0,0 +1,63 @@ +package category + +import ( + "encoding/json" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/base" + "one-api/providers/gemini" + "one-api/types" +) + +func init() { + CategoryMap["gemini"] = &Category{ + ChatComplete: ConvertGeminiFromChatOpenai, + ResponseChatComplete: ConvertGeminiToChatOpenai, + ResponseChatCompleteStrem: GeminiChatCompleteStrem, + ErrorHandler: gemini.RequestErrorHandle, + GetModelName: GetGeminiModelName, + GetOtherUrl: getGeminiOtherUrl, + } +} + +func ConvertGeminiFromChatOpenai(request *types.ChatCompletionRequest) (any, *types.OpenAIErrorWithStatusCode) { + geminiRequest, err := gemini.ConvertFromChatOpenai(request) + if err != nil { + return nil, err + } + + return geminiRequest, nil +} + +func ConvertGeminiToChatOpenai(provider base.ProviderInterface, response *http.Response, request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + geminiResponse := &gemini.GeminiChatResponse{} + err := json.NewDecoder(response.Body).Decode(geminiResponse) + if err != nil { + return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) + } + + return gemini.ConvertToChatOpenai(provider, geminiResponse, request) +} + +func GeminiChatCompleteStrem(provider base.ProviderInterface, request *types.ChatCompletionRequest) requester.HandlerPrefix[string] { + chatHandler := &gemini.GeminiStreamHandler{ + Usage: provider.GetUsage(), + LastCandidates: 0, + LastType: "", + Request: request, + } + + return chatHandler.HandlerStream +} + +func GetGeminiModelName(modelName string) string { + return modelName +} + +func getGeminiOtherUrl(stream bool) string { + if stream { + return "streamGenerateContent?alt=sse" + } + return "generateContent" +} diff --git a/providers/vertexai/chat.go b/providers/vertexai/chat.go new file mode 100644 index 000000000..0321b48c2 --- /dev/null +++ b/providers/vertexai/chat.go @@ -0,0 +1,80 @@ +package vertexai + +import ( + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/vertexai/category" + "one-api/types" +) + +func (p *VertexAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + // 发送请求 + response, errWithCode := p.Send(request) + if errWithCode != nil { + return nil, errWithCode + } + + defer response.Body.Close() + + return p.Category.ResponseChatComplete(p, response, request) +} + +func (p *VertexAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { + // 发送请求 + response, errWithCode := p.Send(request) + if errWithCode != nil { + return nil, errWithCode + } + + return requester.RequestStream(p.Requester, response, p.Category.ResponseChatCompleteStrem(p, request)) +} + +func (p *VertexAIProvider) Send(request *types.ChatCompletionRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + return p.Requester.SendRequestRaw(req) +} + +func (p *VertexAIProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + var err error + p.Category, err = category.GetCategory(request.Model) + if err != nil || p.Category.ChatComplete == nil || p.Category.ResponseChatComplete == nil { + return nil, common.StringErrorWrapper("vertexAI provider not found", "vertexAI_err", http.StatusInternalServerError) + } + + otherUrl := p.Category.GetOtherUrl(request.Stream) + modelName := p.Category.GetModelName(request.Model) + + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(modelName, otherUrl) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError) + } + + headers := p.GetRequestHeaders() + + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + bedrockRequest, errWithCode := p.Category.ChatComplete(request) + if errWithCode != nil { + return nil, errWithCode + } + + // 错误处理 + p.Requester.ErrorHandler = RequestErrorHandle(p.Category.ErrorHandler) + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(bedrockRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + return req, nil +} diff --git a/providers/vertexai/type.go b/providers/vertexai/type.go new file mode 100644 index 000000000..ef5772e47 --- /dev/null +++ b/providers/vertexai/type.go @@ -0,0 +1,29 @@ +package vertexai + +type Credentials struct { + Type string `json:"type"` + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + ClientID string `json:"client_id"` + AuthURI string `json:"auth_uri"` + TokenURI string `json:"token_uri"` + AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"` + ClientX509CertURL string `json:"client_x509_cert_url"` + UniverseDomain string `json:"universe_domain"` +} + +type VertexaiErrors []*VertexaiError + +type VertexaiError struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` +} + +func (e *VertexaiErrors) Error() *VertexaiError { + return (*e)[0] +} diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index e2bfb839a..46282e4c9 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -187,6 +187,13 @@ export const CHANNEL_OPTIONS = { value: 41, color: 'default' }, + 42: { + key: 42, + text: 'VertexAI', + value: 42, + color: 'orange', + url: 'https://console.cloud.google.com/' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 2a6cf2f5f..9dfd90436 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -385,6 +385,17 @@ const typeConfig = { model_mapping: '' }, modelGroup: 'Suno' + }, + 42: { + input: { + models: ['claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307'] + }, + prompt: { + key: '请参考wiki中的文档获取key. https://github.com/MartialBE/one-api/wiki/VertexAI', + other: 'Region|ProjectID', + base_url: '' + }, + modelGroup: 'VertexAI' } };