Commit
refactor, Extract the logic to fetch models from the openrouter API into a helper, so it can be reused

Part of #296
ruiAzevedo19 committed Jul 30, 2024
1 parent e5c1bcc commit 8ae0ab8
Showing 1 changed file with 44 additions and 8 deletions.
52 changes: 44 additions & 8 deletions provider/openrouter/openrouter.go
@@ -2,7 +2,11 @@ package openrouter
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
+	"io"
+	"net/http"
+	"net/url"
 	"strings"
 	"time"
 
@@ -81,16 +85,49 @@ type Pricing struct {
 
 // Models returns which models are available to be queried via this provider.
 func (p *Provider) Models() (models []model.Model, err error) {
-	client := p.client()
+	responseModels, err := p.fetchModels()
+	if err != nil {
+		return nil, err
+	}
+
+	models = make([]model.Model, len(responseModels.Models))
+	for i, model := range responseModels.Models {
+		models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID)
+	}
+
+	return models, nil
+}
+
+// fetchModels returns the list of models of the provider.
+func (p *Provider) fetchModels() (models ModelsList, err error) {
+	modelsURLPath, err := url.JoinPath(p.baseURL, "models")
+	if err != nil {
+		return ModelsList{}, pkgerrors.WithStack(err)
+	}
+	request, err := http.NewRequest("GET", modelsURLPath, nil)
+	if err != nil {
+		return ModelsList{}, pkgerrors.WithStack(err)
+	}
+	request.Header.Set("Accept", "application/json")
 
-	var responseModels openai.ModelsList
+	client := &http.Client{}
+	var responseBody []byte
 	if err := retry.Do( // Query available models with a retry logic cause "openrouter.ai" has failed us in the past.
 		func() error {
-			ms, err := client.ListModels(context.Background())
+			response, err := client.Do(request)
 			if err != nil {
 				return pkgerrors.WithStack(err)
 			}
-			responseModels = ms
+			defer response.Body.Close()
+
+			if response.StatusCode != http.StatusOK {
+				return pkgerrors.Errorf("received status code %d when querying provider models", response.StatusCode)
+			}
+
+			responseBody, err = io.ReadAll(response.Body)
+			if err != nil {
+				return pkgerrors.WithStack(err)
+			}
 
 			return nil
 		},
@@ -99,12 +136,11 @@ func (p *Provider) Models() (models []model.Model, err error) {
 		retry.DelayType(retry.BackOffDelay),
 		retry.LastErrorOnly(true),
 	); err != nil {
-		return nil, err
+		return ModelsList{}, err
 	}
 
-	models = make([]model.Model, len(responseModels.Models))
-	for i, model := range responseModels.Models {
-		models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID)
+	if err = json.Unmarshal(responseBody, &models); err != nil {
+		return ModelsList{}, pkgerrors.WithStack(err)
 	}
 
 	return models, nil
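The point of the extraction is that other code paths can now obtain the raw OpenRouter model list without going through Models(). As an illustration only, here is a minimal sketch of a possible second call site in the same openrouter package; the method name modelAvailable is hypothetical and not part of this commit, while fetchModels, ModelsList.Models and the model ID field are taken from the diff above:

// modelAvailable is a hypothetical sketch of reusing the extracted fetchModels
// helper: it reports whether OpenRouter currently lists a model with the given ID.
func (p *Provider) modelAvailable(id string) (bool, error) {
	responseModels, err := p.fetchModels()
	if err != nil {
		return false, err
	}

	for _, model := range responseModels.Models {
		if model.ID == id {
			return true, nil
		}
	}

	return false, nil
}

Any real caller would have to live in the same package, since fetchModels is unexported.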
