feat: adding support for deepseek-coder models (#8)
srikanth235 authored Jan 25, 2024
1 parent 6af3f04 commit 7da3289
Showing 7 changed files with 65 additions and 38 deletions.
20 changes: 18 additions & 2 deletions .prettierrc.js
@@ -2,5 +2,21 @@ module.exports = {
   trailingComma: "es5",
   tabWidth: 2,
   semi: true,
-  singleQuote: false
-};
+  singleQuote: false,
+  arrowParens: "always",
+  bracketSameLine: false,
+  bracketSpacing: true,
+  experimentalTernaries: false,
+  jsxSingleQuote: false,
+  quoteProps: "as-needed",
+  singleAttributePerLine: false,
+  htmlWhitespaceSensitivity: "css",
+  vueIndentScriptAndStyle: false,
+  proseWrap: "preserve",
+  insertPragma: false,
+  printWidth: 80,
+  requirePragma: false,
+  useTabs: false,
+  embeddedLanguageFormatting: "auto",
+  spaceBeforeFunctionParen: false
+};
1 change: 1 addition & 0 deletions README.md
@@ -22,6 +22,7 @@ Some of the popular LLMs that we recommend are:
 
 - [Mistral](https://mistral.ai/)
 - [CodeLLama](https://github.com/facebookresearch/codellama)
+- [DeepSeek Coder](https://github.com/deepseek-ai/DeepSeek-Coder)
 
 ## Quick Install
 
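A note on configuration (not part of the diff): with this change, a DeepSeek Coder model is selected through Privy's existing custom-model setting rather than a new enum value, because getModel() in AIClient.ts now resolves "custom" to the privy.customModel value. A minimal settings.json sketch, assuming the model has already been pulled into Ollama and that a tag such as deepseek-coder:6.7b-instruct exists locally (the exact tag is an assumption; check your Ollama library):

{
  "privy.model": "custom",
  "privy.customModel": "deepseek-coder:6.7b-instruct"
}

Since getPromptTemplate() (added later in this commit) dispatches on the resolved model name, any custom model whose name starts with "deepseek" is automatically served with Ollama's plain-text prompt template.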
2 changes: 1 addition & 1 deletion lib/extension/package.json
@@ -15,7 +15,7 @@
     "@privy/common": "*",
     "handlebars": "4.7.8",
     "marked": "4.2.12",
-    "modelfusion": "0.113.0",
+    "modelfusion": "0.122.0",
     "secure-json-parse": "2.7.0",
     "simple-git": "3.21.0",
     "zod": "3.22.4"
63 changes: 39 additions & 24 deletions lib/extension/src/ai/AIClient.ts
@@ -1,5 +1,4 @@
 import {
-  Llama2Prompt,
   OpenAITextEmbeddingResponse,
   InstructionPrompt,
   TextStreamingModel,
@@ -9,7 +8,6 @@ import {
   ollama,
   openai,
   streamText,
-  MistralInstructPrompt,
 } from "modelfusion";
 import * as vscode from "vscode";
 import { z } from "zod";
@@ -30,14 +28,14 @@ function getProviderBaseUrl(): string {
   );
 }
 
-function getModel() {
-  return z
+function getModel(): string {
+  let model = z
     .enum(["mistral:instruct", "codellama:instruct", "custom"])
     .parse(vscode.workspace.getConfiguration("privy").get("model"));
-}
-
-function getCustomModel(): string {
-  return vscode.workspace.getConfiguration("privy").get("customModel", "");
+  if (model === "custom") {
+    return vscode.workspace.getConfiguration("privy").get("customModel", "");
+  }
+  return model;
 }
 
 function getProvider() {
@@ -46,6 +44,17 @@ function getProvider() {
     .parse(vscode.workspace.getConfiguration("privy").get("provider"));
 }
 
+function getPromptTemplate() {
+  const model = getModel();
+  if (model.startsWith("mistral")) {
+    return ollama.prompt.Mistral;
+  } else if (model.startsWith("deepseek")) {
+    return ollama.prompt.Text;
+  }
+
+  return ollama.prompt.Llama2;
+}
+
 export class AIClient {
   private readonly apiKeyManager: ApiKeyManager;
   private readonly logger: Logger;
@@ -78,24 +87,29 @@ export class AIClient {
     stop?: string[] | undefined;
     temperature?: number | undefined;
   }): Promise<TextStreamingModel<InstructionPrompt>> {
-    const modelConfiguration = getModel();
     const provider = getProvider();
 
     if (provider.startsWith("llama")) {
       return llamacpp
-        .TextGenerator({
+        .CompletionTextGenerator({
           api: await this.getProviderApiConfiguration(),
-          // TODO the prompt format needs to be configurable for non-Llama2 models
+          promptTemplate: llamacpp.prompt.Llama2,
           maxGenerationTokens: maxTokens,
           stopSequences: stop,
           temperature,
         })
-        .withTextPromptTemplate(Llama2Prompt.instruction());
+        .withInstructionPrompt();
     }
 
     return ollama
-      .ChatTextGenerator({
+      .CompletionTextGenerator({
         api: await this.getProviderApiConfiguration(),
-        model: getModel() === "custom" ? getCustomModel() : getModel(),
+        promptTemplate: getPromptTemplate(),
+        model: getModel(),
         maxGenerationTokens: maxTokens,
         stopSequences: stop,
         temperature,
       })
       .withInstructionPrompt();
   }
@@ -112,28 +126,29 @@
     temperature?: number | undefined;
   }) {
     this.logger.log(["--- Start prompt ---", prompt, "--- End prompt ---"]);
 
-    return streamText(
-      await this.getTextStreamingModel({ maxTokens, stop, temperature }),
-      { instruction: prompt }
-    );
+    return streamText({
+      model: await this.getTextStreamingModel({ maxTokens, stop, temperature }),
+      prompt: {
+        instruction: prompt,
+      },
+    });
   }
 
   async generateEmbedding({ input }: { input: string }) {
     try {
-      const { embedding, response } = await embed(
-        openai.TextEmbedder({
+      const { embedding, rawResponse } = await embed({
+        model: openai.TextEmbedder({
           api: await this.getProviderApiConfiguration(),
           model: "text-embedding-ada-002",
         }),
-        input,
-        { fullResponse: true }
-      );
+        value: input,
+        fullResponse: true,
+      });
 
       return {
         type: "success" as const,
         embedding,
-        totalTokenCount: (response as OpenAITextEmbeddingResponse).usage
+        totalTokenCount: (rawResponse as OpenAITextEmbeddingResponse).usage
           .total_tokens,
       };
     } catch (error: any) {
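For reference, here is a minimal standalone sketch of the object-style modelfusion 0.122 API that this file migrates to. It is an illustration only: the model tag and prompt text are placeholders, and the extension itself threads these values through getModel(), getPromptTemplate(), and getProviderApiConfiguration() as shown in the diff above.

import { ollama, streamText } from "modelfusion";

// Illustrative sketch of the call shape used in AIClient above; not the extension's real wiring.
const model = ollama
  .CompletionTextGenerator({
    model: "deepseek-coder:6.7b-instruct", // placeholder tag; the extension reads it from privy.customModel
    promptTemplate: ollama.prompt.Text, // what getPromptTemplate() selects for deepseek-* models
    maxGenerationTokens: 256,
  })
  .withInstructionPrompt();

const textStream = await streamText({
  model,
  prompt: { instruction: "Write a function that checks whether a number is prime." },
});

for await (const textPart of textStream) {
  process.stdout.write(textPart);
}

Running something like this against a local Ollama instance is a quick way to confirm that the chosen prompt template produces sensible completions before wiring a new model family into getPromptTemplate().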
9 changes: 5 additions & 4 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -34,6 +34,7 @@
     "lib/*"
   ],
   "dependencies": {
+    "modelfusion": "^0.122.0",
     "pnpm": "^8.13.1"
   }
 }
7 changes: 0 additions & 7 deletions template/chat/chat-en.rdt.md
@@ -67,13 +67,6 @@ Developer: {{content}}
 {{/if}}
 {{/each}}
-## Task
-Write a response that continues the conversation.
-Stay focused on current developer request.
-Consider the possibility that there might not be a solution.
-Ask for clarification if the message does not make sense or more input is needed.
-Omit any links.
-Include code snippets (using Markdown) and examples where appropriate.
 ## Response
 Bot:
