Skip to content

Commit

Permalink
add multiple providers in addition to anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
rrfaria committed Oct 13, 2024
1 parent ffa9f11 commit db3e97e
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ pnpm install

```
ANTHROPIC_API_KEY=XXX
PROVIDER= gemini | antrophic | openai | ollama
MODEL_NAME=XXX
GOOGLE_GENERATIVE_AI_API_KEY=XXX
```

Optionally, you can set the debug level:
Expand Down
6 changes: 6 additions & 0 deletions app/lib/.server/llm/api-key.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,11 @@ export function getAPIKey(cloudflareEnv: Env) {
* The `cloudflareEnv` is only used when deployed or when previewing locally.
* In development the environment variables are available through `env`.
*/
const provider = cloudflareEnv.PROVIDER || 'anthropic';

if (provider === 'gemini') {
return cloudflareEnv.GOOGLE_GENERATIVE_AI_API_KEY || (env.GOOGLE_GENERATIVE_AI_API_KEY as string);
}

return env.ANTHROPIC_API_KEY || cloudflareEnv.ANTHROPIC_API_KEY;
}
25 changes: 25 additions & 0 deletions app/lib/.server/llm/get-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import type { ModelFactory } from './providers/modelFactory';
import { AnthropicFactory } from './providers/anthropic';
import { OpenAiFactory } from './providers/openAi';
import { GeminiFactory } from './providers/gemini';
import { OllamaFactory } from './providers/ollama';

export function getModelFactory(provider: string): ModelFactory {
switch (provider.toLowerCase()) {
case 'anthropic': {
return new AnthropicFactory();
}
case 'openai': {
return new OpenAiFactory();
}
case 'gemini': {
return new GeminiFactory();
}
case 'ollama': {
return new OllamaFactory();
}
default: {
throw new Error(`Unsupported provider: ${provider}`);
}
}
}
9 changes: 0 additions & 9 deletions app/lib/.server/llm/model.ts

This file was deleted.

16 changes: 16 additions & 0 deletions app/lib/.server/llm/providers/anthropic.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { createAnthropic } from '@ai-sdk/anthropic';
import type { ModelFactory } from './modelFactory';

export function getAnthropicModel(apiKey: string, modelName: string = 'claude-3-5-sonnet-20240620') {
const anthropic = createAnthropic({
apiKey,
});

return anthropic(modelName);
}

export class AnthropicFactory implements ModelFactory {
createModel(apiKey: string, modelName: string) {
return getAnthropicModel(apiKey, modelName);
}
}
16 changes: 16 additions & 0 deletions app/lib/.server/llm/providers/gemini.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import type { ModelFactory } from './modelFactory';

export function getGeminiModel(apiKey: string, modelName: string = 'gemini-1.5-pro-latest') {
const model = createGoogleGenerativeAI({
apiKey,
});

return model(modelName);
}

export class GeminiFactory implements ModelFactory {
createModel(apiKey: string, modelName: string) {
return getGeminiModel(apiKey, modelName);
}
}
4 changes: 4 additions & 0 deletions app/lib/.server/llm/providers/modelFactory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import type { LanguageModel } from 'ai';
export interface ModelFactory {
createModel(apiKey: string, modelName: string): LanguageModel;
}
15 changes: 15 additions & 0 deletions app/lib/.server/llm/providers/ollama.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { createOllama } from 'ollama-ai-provider';
import type { ModelFactory } from './modelFactory';

export function getOllamaModel(apiKey: string, modelName: string = 'llama3.2:latest') {
const model = createOllama({
baseURL: 'http://172.21.208.1:11434',
});
return model(modelName);
}

export class OllamaFactory implements ModelFactory {
createModel(apiKey: string, modelName: string) {
return getOllamaModel(apiKey, modelName);
}
}
16 changes: 16 additions & 0 deletions app/lib/.server/llm/providers/openAi.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { createOpenAI } from '@ai-sdk/openai';
import type { ModelFactory } from './modelFactory';

export function getOpenAiModel(apiKey: string, modelName: string = 'gpt-4o-mini') {
const model = createOpenAI({
apiKey,
});

return model(modelName);
}

export class OpenAiFactory implements ModelFactory {
createModel(apiKey: string, modelName: string) {
return getOpenAiModel(apiKey, modelName);
}
}
19 changes: 14 additions & 5 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { streamText as _streamText, convertToCoreMessages } from 'ai';
import { getAPIKey } from '~/lib/.server/llm/api-key';
import { getAnthropicModel } from '~/lib/.server/llm/model';
import { getModelFactory } from '~/lib/.server/llm/get-model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';

Expand All @@ -22,13 +22,22 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
const provider = env.PROVIDER || 'anthropic';
const modelName = env.MODEL_NAME || 'default-model';
const factory = getModelFactory(provider);

const model = factory.createModel(getAPIKey(env), modelName);

return _streamText({
model: getAnthropicModel(getAPIKey(env)),
model,
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
headers: {
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
},
headers:
provider === 'anthropic'
? {
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
}
: undefined,
messages: convertToCoreMessages(messages),
...options,
});
Expand Down
8 changes: 6 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
"node": ">=18.18.0"
},
"dependencies": {
"@ai-sdk/anthropic": "^0.0.39",
"@ai-sdk/anthropic": "^0.0.51",
"@ai-sdk/google": "^0.0.51",
"@ai-sdk/openai": "^0.0.66",
"@codemirror/autocomplete": "^6.17.0",
"@codemirror/commands": "^6.6.0",
"@codemirror/lang-cpp": "^6.0.2",
Expand Down Expand Up @@ -54,14 +56,16 @@
"@xterm/addon-fit": "^0.10.0",
"@xterm/addon-web-links": "^0.11.0",
"@xterm/xterm": "^5.5.0",
"ai": "^3.3.4",
"ai": "^3.4.9",
"date-fns": "^3.6.0",
"diff": "^5.2.0",
"framer-motion": "^11.2.12",
"isbot": "^4.1.0",
"istextorbinary": "^9.5.0",
"jose": "^5.6.3",
"nanostores": "^0.10.3",
"ollama": "^0.5.9",
"ollama-ai-provider": "^0.15.1",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-hotkeys-hook": "^4.5.0",
Expand Down
3 changes: 3 additions & 0 deletions worker-configuration.d.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
interface Env {
ANTHROPIC_API_KEY: string;
PROVIDER: string;
MODEL_NAME: string;
GOOGLE_GENERATIVE_AI_API_KEY: string;
}

0 comments on commit db3e97e

Please sign in to comment.