-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresponse_generator.py
37 lines (31 loc) · 942 Bytes
/
response_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
Module for generating LLM response
"""
import cohere
import os
class ResponseGenerator:
    """Generate LLM responses using the Cohere chat API.

    Args:
        model: Name of the Cohere model to use for chat completions.
    """

    def __init__(self, model: str) -> None:
        # Reads COHERE_API_KEY from the environment; raises KeyError if unset.
        self.cohere_client: cohere.ClientV2 = cohere.ClientV2(
            api_key=os.environ["COHERE_API_KEY"]
        )
        self.model: str = model

    def generate_response(self, query: str, prompt: str, context) -> str:
        """Return the model's reply to *query* under the system *prompt*.

        Args:
            query: The user's question, sent as the user message.
            prompt: System prompt steering the model, sent as the system message.
            context: Retrieved documents intended for grounding; currently
                ignored because the ``documents=`` argument below is disabled.

        Returns:
            The text of the first content item of the chat response.
        """
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": query},
        ]
        response = self.cohere_client.chat(
            messages=messages,
            model=self.model,
            max_tokens=200,
            temperature=0.1,  # low temperature for near-deterministic answers
            # documents=context  # NOTE(review): disabled, so `context` is unused — confirm intent
        )
        return response.message.content[0].text

    def predict(self, query: str, prompt: str, context) -> str:
        """Alias for :meth:`generate_response` (common predictor interface)."""
        return self.generate_response(query, prompt, context)