Skip to content

Commit

Permalink
🐛 fix: fix vertex AI credential acquisition failure when using a proxy (
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Jul 16, 2024
1 parent 7a6e928 commit 5ded841
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
1 change: 1 addition & 0 deletions common/utils/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ContextKey string

const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
const ProxySock5AddrKey ContextKey = "proxySock5Addr"
const ProxyAddrKey ContextKey = "proxyAddr"

func ProxyFunc(req *http.Request) (*url.URL, error) {
proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
Expand Down
32 changes: 31 additions & 1 deletion providers/vertexai/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"one-api/common/cache"
"one-api/common/logger"
"one-api/common/requester"
"one-api/common/utils"
"one-api/model"
"one-api/providers/base"
"one-api/providers/vertexai/category"
Expand All @@ -19,7 +22,9 @@ import (

credentials "cloud.google.com/go/iam/credentials/apiv1"
"cloud.google.com/go/iam/credentials/apiv1/credentialspb"
"golang.org/x/net/proxy"
"google.golang.org/api/option"
"google.golang.org/grpc"
)

const TokenCacheKey = "api_token:vertexai"
Expand Down Expand Up @@ -104,7 +109,11 @@ func (p *VertexAIProvider) GetToken() (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

client, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(p.Channel.Key)))
if p.Channel.Proxy != nil && *p.Channel.Proxy != "" {
ctx = context.WithValue(ctx, utils.ProxyAddrKey, *p.Channel.Proxy)
}

client, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(p.Channel.Key)), option.WithGRPCDialOption(grpc.WithContextDialer(customDialer)))
if err != nil {
return "", fmt.Errorf("failed to create IAM credentials client: %w", err)
}
Expand Down Expand Up @@ -163,3 +172,24 @@ func errorHandle(vertexaiError *VertexaiError) *types.OpenAIError {
Code: vertexaiError.Error.Code,
}
}

func customDialer(ctx context.Context, addr string) (net.Conn, error) {
proxyAddress, ok := ctx.Value(utils.ProxyAddrKey).(string)
if !ok || proxyAddress == "" {
return net.Dial("tcp", addr)
}

proxyURL, err := url.Parse(proxyAddress)
if err != nil {
return nil, fmt.Errorf("error parsing proxy address: %w", err)
}

dialer := &net.Dialer{}

dialerProxy, err := proxy.FromURL(proxyURL, dialer)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP dialer: %v", err)
}

return dialerProxy.Dial("tcp", addr)
}

0 comments on commit 5ded841

Please sign in to comment.