Skip to content

Commit

Permalink
Merge pull request #25 from rjmacarthy/deepseek
Browse files Browse the repository at this point in the history
support deepseek
  • Loading branch information
rjmacarthy authored Jan 5, 2024
2 parents 4df29f9 + f00733d commit a138f3d
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 26 deletions.
71 changes: 51 additions & 20 deletions src/prompts.ts
Original file line number Diff line number Diff line change
@@ -1,47 +1,56 @@
export const systemMessage = `
<<SYS>>
You are a helpful, respectful and honest coding assistant.
Always reply with using markfown.
For code refactoring use markdown code formatting.
If you are not sure which language formatting to use, use \`typescript\`
<</SYS>>
const systemMesage = `You are a helpful, respectful and honest coding assistant.
Always reply with using markfown.
For code refactoring use markdown code formatting.
If you are not sure which language formatting to use, use \`typescript\`
`

export const explain = (code: string) =>
export const getSystemMessage = (modelType: string) => {
return modelType.includes('deepseek')
? systemMesage
: `<<SYS>>
${systemMesage}
<</SYS>>`
}

export const explain = (code: string, modelType: string) =>
`
${systemMessage}
${getSystemMessage(modelType)}
Explain the following code \`\`\`${code}\`\`\` do not waffle on.
`

export const addTypes = (code: string) =>
export const addTypes = (code: string, modelType: string) =>
`
${systemMessage}
${getSystemMessage(modelType)}
Add types to the following code, keep the code the same just add the types \`\`\`${code}\`\`\`.
`

export const refactor = (code: string) =>
export const refactor = (code: string, modelType: string) =>
`
${systemMessage}
${getSystemMessage(modelType)}
Refactor the following code \`\`\`${code}\`\`\` do not change how it works.
Always reply with markdown for code blocks formatting e.g if its typescript use \`typescript\` or \`python\`.
If you are not sure which language this is add \`typescript\`
`

export const addTests = (code: string) =>
export const addTests = (code: string, modelType: string) =>
`
${systemMessage}
${getSystemMessage(modelType)}
Write unit tests for the following \`\`\`${code}\`\`\` use the most popular testing library for the inferred language.
`

export const generateDocs = (code: string) =>
export const generateDocs = (code: string, modelType: string) =>
`
${systemMessage}
${getSystemMessage(modelType)}
Generate documentation \`\`\`${code}\`\`\` use the most popular documentation for the inferred language e.g JSDoc for JavaScript.
`

export const chatMessage = (messages: Message[], selection: string) =>
export const chatMessageLlama = (
messages: Message[],
selection: string,
modelType: string
) =>
`
${systemMessage}
${messages.length === 1 ? getSystemMessage(modelType) : ''}
${messages
.map((message) =>
Expand All @@ -53,8 +62,30 @@ export const chatMessage = (messages: Message[], selection: string) =>
)
.join('\n')}
`

export const chatMessageDeepSeek = (
messages: Message[],
selection: string,
modelType: string
) =>
`
${messages.length === 1 ? getSystemMessage(modelType) : ''}
${messages
.map((message) =>
message.role === 'user'
? `### Instruction:
${message.content} ${selection ? ` \`\`\`${selection}\`\`\` ` : ''}`
: `
### Response:
${message.content}
<|EOT|>
`
)
.join('\n')}
`
interface Prompts {
[key: string]: (code: string) => string
[key: string]: (code: string, modelType: string) => string
}

export const codeActionTypes = ['add-types', 'refactor']
Expand Down
4 changes: 4 additions & 0 deletions src/providers/completion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ export class CompletionProvider implements InlineCompletionItemProvider {
private getPrompt(document: TextDocument, position: Position) {
const { prefix, suffix } = this.getContext(document, position)

if (this._model.includes('deepseek')) {
return `<|fim▁begin|>${prefix}<|fim▁hole|>${suffix}<|fim▁end|>`
}

return `<PRE> ${prefix} <SUF> ${suffix} <MID>`
}

Expand Down
27 changes: 22 additions & 5 deletions src/providers/sidebar.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import * as vscode from 'vscode'
import { chatCompletion, getTextSelection, openDiffView } from '../utils'
import { chatMessage } from '../prompts'
import { chatMessageDeepSeek, chatMessageLlama } from '../prompts'
import { getContext } from '../context'

export class SidebarProvider implements vscode.WebviewViewProvider {
view?: vscode.WebviewView
_doc?: vscode.TextDocument
private _config = vscode.workspace.getConfiguration('twinny')
private _model = this._config.get('chatModelName') as string

constructor(private readonly _extensionUri: vscode.Uri) {}

Expand Down Expand Up @@ -37,11 +39,23 @@ export class SidebarProvider implements vscode.WebviewViewProvider {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(data: any) => {
const context = getContext()
const modelType = this._model.includes('llama') ? 'llama' : 'deepseek'

if (data.type === 'chatMessage') {
chatCompletion('chat', this.view, (selection: string) =>
chatMessage(data.data as Message[], selection)
)
chatCompletion('chat', this.view, (selection: string) => {
if (this._model.includes('deepseek')) {
return chatMessageDeepSeek(
data.data as Message[],
selection,
modelType
)
}
return chatMessageLlama(
data.data as Message[],
selection,
modelType
)
})
}
if (data.type === 'openDiff') {
const editor = vscode.window.activeTextEditor
Expand Down Expand Up @@ -75,7 +89,10 @@ export class SidebarProvider implements vscode.WebviewViewProvider {
if (data.type === 'getTwinnyWorkspaceContext') {
this.view?.webview.postMessage({
type: `twinnyWorkSpaceContext-${data.key}`,
value: context?.workspaceState.get(`twinnyWorkSpaceContext-${data.key}`) || ''
value:
context?.workspaceState.get(
`twinnyWorkSpaceContext-${data.key}`
) || ''
})
}
if (data.type === 'setTwinnyWorkSpaceContext') {
Expand Down
3 changes: 2 additions & 1 deletion src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ export function chatCompletion(
const hostname = config.get('ollamaBaseUrl') as string
const port = config.get('ollamaApiPort') as number
const selection = editor?.selection
const modelType = chatModel.includes('llama') ? 'llama' : 'deepseek'
const text = editor?.document.getText(selection) || ''
const template = prompts[type] ? prompts[type](text) : ''
const template = prompts[type] ? prompts[type](text, modelType) : ''
const prompt: string = template ? template : getPrompt?.(text) || ''

let completion = ''
Expand Down

0 comments on commit a138f3d

Please sign in to comment.