-
Notifications
You must be signed in to change notification settings - Fork 0
/
chat_utils.py
65 lines (55 loc) · 2.03 KB
/
chat_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
from typing import List, Literal, TypedDict
# Speaker roles a message may carry. "system" is listed because
# format_tokens accepts an optional leading system message.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """One chat turn: who is speaking and what was said."""

    role: Role
    content: str


# A dialog is an ordered list of messages.
Dialog = List[Message]
# Markers placed around each user instruction.
B_INST = "[INST]"
E_INST = "[/INST]"

# Markers placed around the system prompt when it is merged into the
# first user message.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
def format_tokens(dialogs, tokenizer):
    """Tokenize chat dialogs into ``[INST]``-delimited prompt token lists.

    Each dialog is a sequence of ``{"role": ..., "content": ...}`` messages.
    An optional leading ``"system"`` message is merged into the first
    following message, wrapped in ``B_SYS``/``E_SYS``. After that merge the
    messages must alternate user/assistant, starting and ending with a user
    message. Every completed user/assistant pair is encoded as one segment
    followed by ``tokenizer.eos_token_id``; the final user message is encoded
    without a trailing EOS.

    Please verify that your tokenizer supports adding "[INST]", "[/INST]" to
    your inputs. Here, they are inserted manually as plain text.

    Args:
        dialogs: Iterable of dialogs (each a sequence of message dicts).
        tokenizer: Object exposing ``encode(text)`` (returning an iterable of
            token ids) and an ``eos_token_id`` attribute.

    Returns:
        A list holding one list of token ids per dialog, in input order.

    Raises:
        AssertionError: If a dialog is empty, consists only of a system
            message, does not alternate user/assistant roles, or does not
            end with a user message.
    """
    prompt_tokens = []
    for dialog in dialogs:
        # Explicit raises (rather than ``assert``) keep the validation active
        # under ``python -O`` while preserving the exception type.
        if not dialog:
            raise AssertionError("dialog must contain at least one message")

        if dialog[0]["role"] == "system":
            if len(dialog) < 2:
                raise AssertionError(
                    "a 'system' message must be followed by a user message"
                )
            # Fold the system prompt into the next message; a new list is
            # built so the caller's dialog is left untouched.
            merged = {
                "role": dialog[1]["role"],
                "content": B_SYS
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
            dialog = [merged, *dialog[2:]]

        roles_alternate = all(
            msg["role"] == "user" for msg in dialog[::2]
        ) and all(msg["role"] == "assistant" for msg in dialog[1::2])
        if not roles_alternate:
            raise AssertionError(
                "model only supports 'system','user' and 'assistant' roles, "
                "starting with user and alternating (u/a/u/a/u...)"
            )
        if dialog[-1]["role"] != "user":
            raise AssertionError(
                f"Last message must be from user, got {dialog[-1]['role']}"
            )

        # Accumulate in place; ``sum(lists, [])`` would copy the growing
        # list on every turn.
        dialog_tokens: List[int] = []
        for prompt, answer in zip(dialog[::2], dialog[1::2]):
            dialog_tokens.extend(
                tokenizer.encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                )
            )
            dialog_tokens.append(tokenizer.eos_token_id)

        # The trailing user turn is left open (no EOS) for the model to answer.
        dialog_tokens.extend(
            tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
            )
        )
        prompt_tokens.append(dialog_tokens)
    return prompt_tokens
def read_dialogs_from_file(file_path):
    """Load dialogs from a JSON file.

    Args:
        file_path: Path to a JSON file.

    Returns:
        The parsed JSON content, returned as-is without validation.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale encoding when decoding the file.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)