Merge pull request #36 from DevoteamNL/feature/stream-ai-response
Feature/stream ai response
hardik-id authored May 14, 2024
2 parents 53f9a4f + b46638c commit 9412655
Showing 10 changed files with 310 additions and 121 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -167,3 +167,7 @@ fabric.properties

/.idea/copilot/*
```

# Development settings
.devcontainer/
.vscode/
2 changes: 1 addition & 1 deletion package.json
@@ -20,7 +20,7 @@
"test:e2e": "jest --config ./test/jest-e2e.json"
},
"dependencies": {
"@azure/openai": "1.0.0-beta.6",
"@azure/openai": "1.0.0-beta.12",
"@azure/search-documents": "^12.0.0",
"@nestjs/common": "^9.4.3",
"@nestjs/config": "^2.3.1",
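
The dependency bump from `1.0.0-beta.6` to `1.0.0-beta.12` is what makes the streaming path in this PR possible: the later betas expose `streamChatCompletions`, which returns an `EventStream<ChatCompletions>` that can be consumed with `for await`. A minimal sketch of that surface, assuming a standard Azure OpenAI setup — the endpoint, key handling, and deployment name below are placeholders, not values from this repo:

```ts
import { OpenAIClient, AzureKeyCredential } from '@azure/openai';

// Placeholder endpoint/key wiring; the real client construction lives in
// AzureOpenAIClientService, which is not shown in this diff.
const client = new OpenAIClient(
  'https://my-resource.openai.azure.com/',
  new AzureKeyCredential(process.env.AZURE_OPENAI_KEY ?? ''),
);

async function streamDemo(): Promise<void> {
  // streamChatCompletions returns EventStream<ChatCompletions>; each event
  // carries incremental `delta` fragments instead of a full message.
  const events = await client.streamChatCompletions('gpt-4-32k', [
    { role: 'user', content: 'Say hello.' },
  ]);
  for await (const event of events) {
    for (const choice of event.choices) {
      if (choice.delta?.content) {
        process.stdout.write(choice.delta.content);
      }
    }
  }
}
```
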
6 changes: 4 additions & 2 deletions src/message/message.service.ts
@@ -5,7 +5,7 @@ import { Message } from './entities/message.entity';
import { Thread } from '../thread/entities/thread.entity';
import { In, Repository } from 'typeorm';
import { InjectRepository } from '@nestjs/typeorm';
import { ChatMessage } from '@azure/openai';
import { ChatResponseMessage } from '@azure/openai';

@Injectable()
export class MessageService {
@@ -30,7 +30,9 @@ export class MessageService {
return `This action removes a #${id} message`;
}

async findAllMessagesByThreadId(threadId: number): Promise<ChatMessage[]> {
async findAllMessagesByThreadId(
threadId: number,
): Promise<ChatResponseMessage[]> {
const messages: Message[] = await this.messageRepository.find({
where: { thread: { id: threadId } },
order: { id: 'ASC' },
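
The rest of the method is collapsed in this view, so the entity-to-message mapping is not visible. One plausible shape for that step, assuming each `Message` entity persists the chat payload in a `data` column (a sketch, not code from this PR):

```ts
import { ChatResponseMessage } from '@azure/openai';

// Minimal structural stand-in for the Message entity referenced above; the
// real entity lives in src/message/entities/message.entity.ts.
interface StoredMessage {
  data: ChatResponseMessage;
}

// Hypothetical projection from persisted rows to the payload shape the
// OpenAI client expects; the ASC order in the query keeps turns in sequence.
const toChatMessages = (rows: StoredMessage[]): ChatResponseMessage[] =>
  rows.map((row) => row.data);
```
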
15 changes: 9 additions & 6 deletions src/openai/chat-history.service.ts
@@ -1,14 +1,14 @@
import { Injectable } from '@nestjs/common';
import { ChatMessage } from '@azure/openai';
import { ChatResponseMessage } from '@azure/openai';
import { MessageService } from '../message/message.service';

@Injectable()
export class ChatHistoryService {
private chatHistories = new Map<string, ChatMessage[]>();
private chatHistories = new Map<string, ChatResponseMessage[]>();

constructor(private messageService: MessageService) {}

async initChatHistory(threadId: string): Promise<ChatMessage[]> {
async initChatHistory(threadId: string): Promise<ChatResponseMessage[]> {
if (!this.chatHistories.has(threadId)) {
const history = await this.messageService.findAllMessagesByThreadId(
+threadId,
@@ -18,7 +18,10 @@ export class ChatHistoryService {
return this.chatHistories.get(threadId);
}

async addMessage(threadId: string, message: ChatMessage): Promise<void> {
async addMessage(
threadId: string,
message: ChatResponseMessage,
): Promise<void> {
// Retrieve the current chat history for the thread
const history = this.chatHistories.get(threadId) || [];

@@ -37,7 +40,7 @@
}
}

addSystemMessage(threadId: string, systemMessage: ChatMessage): void {
addSystemMessage(threadId: string, systemMessage: ChatResponseMessage): void {
const history = this.chatHistories.get(threadId) || [];
history.unshift(systemMessage); // Prepend system message to the chat history

@@ -53,7 +56,7 @@
this.chatHistories.set(threadId, updatedHistory);
}

getChatHistory(threadId: string): ChatMessage[] {
getChatHistory(threadId: string): ChatResponseMessage[] {
return this.chatHistories.get(threadId) || [];
}
}
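
Taken together, the retyped API reads as a write-through cache: the first call per thread hydrates the in-memory `Map` from the database, and later reads are served from memory. A hedged usage sketch (not code from this PR; assumes the service is injected the usual NestJS way):

```ts
// Usage sketch against the public API shown above.
async function demo(chatHistoryService: ChatHistoryService): Promise<void> {
  const threadId = '42';

  // First call loads persisted messages for the thread into the in-memory map.
  await chatHistoryService.initChatHistory(threadId);

  // The system prompt is prepended; user/assistant turns are appended.
  chatHistoryService.addSystemMessage(threadId, {
    role: 'system',
    content: 'You are a helpful assistant.',
  });
  await chatHistoryService.addMessage(threadId, {
    role: 'user',
    content: 'Hello!',
  });

  // Subsequent reads are served from memory, not the database.
  const history = chatHistoryService.getChatHistory(threadId);
  console.log(history.length);
}
```
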
214 changes: 166 additions & 48 deletions src/openai/openai-chat.service.ts
@@ -1,42 +1,83 @@
import { Inject, Injectable, Logger } from '@nestjs/common';
import { ChatMessage } from '@azure/openai';
import { Injectable, Logger } from '@nestjs/common';
import {
ChatResponseMessage,
ChatRequestFunctionMessage,
EventStream,
ChatCompletions,
} from '@azure/openai';
import { MessageService } from '../message/message.service';
import { AzureOpenAIClientService } from './azure-openai-client.service';
import { PluginService } from 'src/plugin';
import { Message } from '../message/entities/message.entity';
import { ServerResponse } from 'http';

export enum MetadataTagName {
USER_MESSAGE_ID = 'userMessageId',
USER_MESSAGE_CREATED_AT = 'userMessageCreatedAt',
THREAD_ID = 'threadId',
ROLE = 'role',
AI_MESSAGE_ID = 'aiMessageId',
AI_MESSAGE_CREATED_AT = 'aiMessageCreatedAt',
}

interface MetadataContent {
data: string;
metadataTag: MetadataTagName;
}

/**
* Writes a sequence of metadata to the server response stream
* @param writableStream server response
* @param dataForChunks metadata to be written
*/
const writeMetadataToStream = (
writableStream: ServerResponse,
dataForChunks: MetadataContent[],
) => {
for (const dataObj of dataForChunks) {
writableStream.write(`[[${dataObj.metadataTag}=${dataObj.data}]]`);
}
};

@Injectable()
export class OpenaiChatService {
private readonly logger = new Logger(OpenaiChatService.name);
private readonly gpt35Deployment = 'gpt-35-turbo';
private readonly gpt4Deployment = 'gpt-4';
private readonly gpt4_32K_Deployment = 'gpt-4-32k';

constructor(
private readonly messageService: MessageService,
private readonly azureOpenAIClient: AzureOpenAIClientService,
private readonly pluginService: PluginService,
) {}

// Get an employee's professional work experience details based on a given employee name, certificate name, or skill name

// System message
// Query message from user
// Function information
async getChatResponse({
/**
* Streams the AI chat response to the user request and persists it
* @param param0 request context, response stream, and user-message metadata
*/
async getChatResponseStream({
senderName,
senderEmail,
threadId,
plugin,
}): Promise<Message> {
writableStream,
userMessageId,
userMessageCreatedAt,
}: {
senderName: string;
senderEmail: string;
threadId: number;
plugin?: string;
writableStream: ServerResponse;
userMessageId: number;
userMessageCreatedAt: string;
}) {
// Initialize the message array with existing messages or an empty array
const chatHistory = await this.messageService.findAllMessagesByThreadId(
threadId,
);
const chatHistory: Array<ChatRequestFunctionMessage | ChatResponseMessage> =
await this.messageService.findAllMessagesByThreadId(threadId);

try {
// Initialize chat session with System message
// Generic prompt engineering
const systemMessage: ChatMessage = {
const systemMessage: ChatResponseMessage = {
role: 'system',
content: `Current Date and Time is ${new Date().toISOString()}.
User's name is ${senderName} and user's emailID is ${senderEmail}.
@@ -55,8 +96,9 @@ If user just says Hi or how are you to start conversation, you can respond with
};
chatHistory.unshift(systemMessage);
this.logger.log(`CHAT_HISTORY: ${JSON.stringify(chatHistory)}`);

const completion = await this.azureOpenAIClient.getChatCompletions(
this.gpt4Deployment,
this.gpt4_32K_Deployment,
chatHistory,
{
temperature: 0.1,
@@ -65,66 +107,142 @@
),
},
);
const initial_response = completion.choices[0].message;
const initial_response_message = await this.messageService.create({
const initialResponse = completion.choices[0].message;
const initialResponseMessagePromise = this.messageService.create({
threadId,
data: initial_response,
data: initialResponse,
});
chatHistory.push(initial_response);

chatHistory.push(initialResponse);
this.logger.log(
`INITIAL_RESPONSE: ${JSON.stringify(completion.choices[0].message)}`,
);
const functionCall = initial_response.functionCall;

const { functionCall } = initialResponse;
this.logger.log(`FUNCTION_CALLING: ${JSON.stringify(functionCall)}`);

writeMetadataToStream(writableStream, [
{
data: threadId.toString(),
metadataTag: MetadataTagName.THREAD_ID,
},
{
data: userMessageId.toString(),
metadataTag: MetadataTagName.USER_MESSAGE_ID,
},
{
data: new Date(userMessageCreatedAt).toISOString(),
metadataTag: MetadataTagName.USER_MESSAGE_CREATED_AT,
},
]);
writableStream.emit('drain');

if (functionCall && functionCall.name) {
const function_response = await this.pluginService.executeFunction(
const functionResponse = await this.pluginService.executeFunction(
functionCall.name,
functionCall.arguments,
senderEmail,
);
const calledFunction = this.pluginService.findDefinition(
functionCall.name,
);
// chatHistory.push({
// role: function_response.role,
// functionCall: {
// name: functionCall.name,
// arguments: function_response.functionCall.arguments,
// },
// content: '',
// });
chatHistory.push({
role: 'function',
name: functionCall.name,
content: function_response.toString() + calledFunction.followUpPrompt,
});
await this.messageService.create({
const creationPromise = this.messageService.create({
threadId,
data: {
role: 'function',
name: functionCall.name,
content:
function_response.toString() + calledFunction.followUpPrompt,
functionResponse.toString() + calledFunction.followUpPrompt,
},
});
chatHistory.push({
role: 'function',
name: functionCall.name,
content: functionResponse.toString() + calledFunction.followUpPrompt,
});

this.logger.debug(`########`);
this.logger.debug(chatHistory);
const final_completion =
await this.azureOpenAIClient.getChatCompletions(
calledFunction.followUpModel || this.gpt35Deployment,

const finalCompletionEventStream: EventStream<ChatCompletions> =
await this.azureOpenAIClient.streamChatCompletions(
calledFunction.followUpModel || this.gpt4_32K_Deployment,
chatHistory,
{ temperature: calledFunction.followUpTemperature || 0 },
);
const final_response: ChatMessage = final_completion.choices[0].message;
this.logger.log(`final_response Response:`);
this.logger.log(final_response);
chatHistory.push(final_response);
return await this.messageService.create({

const NO_CONTENT = '';
const responseMessage: ChatResponseMessage = {
content: NO_CONTENT,
role: NO_CONTENT,
};

for await (const event of finalCompletionEventStream) {
for (let i = 0; i < event.choices.length; i++) {
if (event.choices[i].delta) {
if (
event.choices[i].delta.role &&
responseMessage.role === NO_CONTENT
) {
writeMetadataToStream(writableStream, [
{
data: event.choices[i].delta.role,
metadataTag: MetadataTagName.ROLE,
},
]);
responseMessage.role = event.choices[i].delta.role;
}
if (event.choices[i].delta.content) {
writableStream.write(event.choices[i].delta.content);
responseMessage.content += event.choices[i].delta.content;
}
if (i % 2 === 0) {
writableStream.emit('drain');
}
}
}
}
this.logger.log(responseMessage);
// chatHistory.push(responseMessage);

await creationPromise;
const createdMessage = await this.messageService.create({
threadId,
data: final_response,
data: responseMessage,
});
writeMetadataToStream(writableStream, [
{
data: createdMessage.id.toString(),
metadataTag: MetadataTagName.AI_MESSAGE_ID,
},
{
data: new Date(createdMessage.createdAt).toISOString(),
metadataTag: MetadataTagName.AI_MESSAGE_CREATED_AT,
},
]);
} else {
this.logger.log('No functionCall');

const initialResponseMessage = await initialResponseMessagePromise;
writeMetadataToStream(writableStream, [
{
data: initialResponseMessage.data.role,
metadataTag: MetadataTagName.ROLE,
},
]);
writableStream.write(initialResponseMessage.data.content);
writeMetadataToStream(writableStream, [
{
data: initialResponseMessage.id.toString(),
metadataTag: MetadataTagName.AI_MESSAGE_ID,
},
{
data: new Date(initialResponseMessage.createdAt).toISOString(),
metadataTag: MetadataTagName.AI_MESSAGE_CREATED_AT,
},
]);
}
return initial_response_message;
writableStream.end();
} catch (error) {
this.logger.log(error);
throw error;
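
The visible diff only covers the producer side of the stream: tokens go out as plain text, while IDs, timestamps, and the role are interleaved as `[[tag=value]]` markers via `writeMetadataToStream`. A consumer therefore has to split markers from content. A hypothetical parser for one chunk — none of the PR's client code is shown in this view:

```ts
// Matches the [[tag=value]] markers emitted by writeMetadataToStream above.
const METADATA_RE = /\[\[([^=\]]+)=([^\]]*)\]\]/g;

// Strips metadata markers out of a chunk and collects them by tag,
// returning the remaining plain token text.
function parseChunk(chunk: string): {
  text: string;
  metadata: Record<string, string>;
} {
  const metadata: Record<string, string> = {};
  const text = chunk.replace(METADATA_RE, (_match, tag: string, value: string) => {
    metadata[tag] = value;
    return '';
  });
  return { text, metadata };
}

// e.g. parseChunk('[[threadId=7]][[role=assistant]]Hel')
//   -> { text: 'Hel', metadata: { threadId: '7', role: 'assistant' } }
```

One caveat this sketch ignores: a marker can be split across two network chunks, so a real consumer would buffer a trailing partial `[[` prefix before parsing.
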
