[Chat] support session-based training #4313

Merged · 1 commit · Jul 28, 2023
87 changes: 87 additions & 0 deletions applications/Chat/coati/dataset/conversation.py
@@ -0,0 +1,87 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    sep: str = "</s>"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=[[x, y] for x, y in self.messages],
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep
        }


conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
    sep="</s>",
)

default_conversation = conv
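
# Example: appending ("Human", "Hi") and ("Assistant", None) to a copy of
# `default_conversation` makes get_prompt() return
# "<system prompt>Human: Hi</s>Assistant: ", ready for the model to complete.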
97 changes: 87 additions & 10 deletions applications/Chat/coati/dataset/sft_dataset.py
@@ -15,7 +15,7 @@
import copy
import random
from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Callable, Dict, List, Sequence, Tuple

import torch
import torch.distributed as dist
@@ -25,11 +25,21 @@

from colossalai.logging import get_dist_logger

from .conversation import default_conversation
from .utils import is_rank_0, jload

# The following is a template prompt for a 4-round conversation.
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.

Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""
# Note that we compute the loss only on the assistant's answer tokens.
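# For example, in "Human: Hi</s>Assistant: Hello!</s>", only the tokens of
# "Hello!</s>" keep their labels; every other position is set to IGNORE_INDEX.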

logger = get_dist_logger()

IGNORE_INDEX = -100
DEFAULT_EOS_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input":
        ("Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -107,6 +117,61 @@ def preprocess(
    return dict(input_ids=input_ids, labels=labels)


def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
                            max_length: int) -> Dict:
    """Preprocess the conversation data by tokenizing."""
    conversations = []
    intermediates = []
    for source in sources:
        header = f"{default_conversation.system}"
        conversation, intermediate = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
        intermediates.append(intermediate)

    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
    input_ids = conversations_tokenized["input_ids"]
    targets = copy.deepcopy(input_ids)

    assert len(targets) == len(intermediates)
    for target, inters in zip(targets, intermediates):
        mask = torch.zeros_like(target, dtype=torch.bool)
        for inter in inters:
            tokenized = _tokenize_fn(inter, tokenizer, max_length)
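
            # inter = [start, end]: `start` is the conversation up to and including
            # "Assistant: ", while `end` additionally contains the assistant's reply
            # and the EOS token. The difference between their tokenized lengths
            # delimits the answer span (start_idx is shifted back by one token to be
            # robust at the tokenization boundary); only that span keeps its labels.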

            start_idx = tokenized["input_ids"][0].size(0) - 1
            end_idx = tokenized["input_ids"][1].size(0)

            mask[start_idx:end_idx] = True
        target[~mask] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def _add_speaker_and_signal(header: str,
                            source: List[Dict],
                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
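    """Add speaker roles and EOS separators to the conversation.

    Returns the full conversation string and, for each assistant ("gpt") turn,
    a [start, end] pair of prefix strings that is later tokenized to locate
    the answer tokens for loss masking.
    """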
    END_SIGNAL = DEFAULT_EOS_TOKEN
    conversation = header
    intermediate = []
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = default_conversation.roles[1]
        else:
            from_str = 'unknown'

        value = from_str + ": " + sentence["value"] + END_SIGNAL
        if sentence["from"].lower() == "gpt":
            start = conversation + from_str + ": "
            end = conversation + value
            intermediate.append([start, end])
        if get_conversation:
            conversation += value
    return conversation, intermediate
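

# Illustration (assuming the default "Human"/"Assistant" roles): the source
# [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]
# yields "<system>Human: Hi</s>Assistant: Hello!</s>" together with one
# [start, end] pair whose tokenized lengths delimit "Hello!</s>" for the loss.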


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

@@ -125,15 +190,27 @@ def __init__(self,
        list_data_dict = list_data_dict[:max_datasets_size]

        logger.info("Formatting inputs...")
-        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
-        sources = [
-            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
-            for example in list_data_dict
-        ]
-        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
-
-        logger.info("Tokenizing inputs... This may take some time...")
-        data_dict = preprocess(sources, targets, tokenizer, max_length)
+        if "conversations" not in list_data_dict[0]:
+            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+            sources = [
+                prompt_input.format_map(example)
+                if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
+            ]
+            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+
+            data_dict = preprocess(sources, targets, tokenizer, max_length)
+        else:
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+
+            sources = [conv["conversations"] for conv in list_data_dict]
+            data_dict = preprocess_conversation(sources, tokenizer, max_length)

        if is_rank_0():
            logger.info("Tokenizing finish.")

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
44 changes: 44 additions & 0 deletions applications/Chat/examples/README.md
@@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
- [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
@@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>

### Conversation dataset generation

To further improve the model's ability to handle multi-turn conversations, the training data needs to include multi-turn samples. However, the InstructWild and Alpaca datasets consist of single-turn conversations only, and their organization is not suitable for storing multi-turn conversations. Therefore, besides converting those datasets, we also include multi-turn conversation datasets such as ShareGPT and transform them into the training format supported by ColossalChat.

A sample in the conversation dataset should have the following fields:

* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (list, compulsory): Conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.

A simple example:
```json
{
    "type": "instruction",
    "language": "English",
    "dataset": "Alpaca",
    "conversations": [
        {
            "from": "human",
            "value": "Give three tips for staying healthy."
        },
        {
            "from": "gpt",
            "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
        }
    ],
    "id": 1
}
```

> **NOTE:** Only the key `conversations` is compulsory for training; the other keys serve as metadata. The number of turns in `conversations` can vary.

You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat:

```bash
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```
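
As a quick sanity check, a minimal sketch along the following lines (assuming the converted dataset was saved as `dataset.json`) can verify that every sample carries the compulsory key:

```python
import json

# Load the converted dataset and check each sample for the compulsory field.
with open("dataset.json", encoding="utf-8") as f:
    samples = json.load(f)

for sample in samples:
    assert "conversations" in sample, f"sample {sample.get('id')} lacks 'conversations'"
    for turn in sample["conversations"]:
        assert turn["from"] in ("human", "gpt"), f"unexpected speaker: {turn['from']}"
```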

## Stage1 - Supervised instructs tuning

Stage1 is supervised instruction fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
79 changes: 79 additions & 0 deletions applications/Chat/examples/generate_conversation_dataset.py
@@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    # Datasets with the same format ("instruction", "input", "output") as Alpaca
    # can be converted into one-round conversations.
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]

    assert len(instructions) == len(inputs) == len(outputs)

    for idx in range(len(instructions)):
        human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
        human = {"from": "human", "value": human_utterance}

        gpt_utterance = outputs[idx]
        gpt = {"from": "gpt", "value": gpt_utterance}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset


def generate_sharegpt():
    # ShareGPT data requires less processing.
    conversation_dataset = []
    dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
                           data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
                           split="train")

    conversations = dataset["conversations"]

    for idx in range(len(conversations)):
        for conv in conversations[idx]:
            # We don't need the markdown and text values.
            del conv["markdown"]
            del conv["text"]

        conversation = dict(type="conversation",
                            language="Multilingual",
                            dataset="ShareGPT",
                            conversations=conversations[idx])
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default="All",
                        choices=["Alpaca", "ShareGPT", "All"],
                        help="which dataset to convert, All will combine Alpaca and ShareGPT")
    parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []

    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    for idx, sample in enumerate(conversation_dataset):
        sample["id"] = idx + 1

    with open(args.save_path, mode='w') as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
8 changes: 3 additions & 5 deletions applications/Chat/examples/train_sft.py
@@ -74,8 +74,8 @@ def train(args):
            padding_side="right",
            use_fast=False,
        )
-        tokenizer.eos_token = '<\s>'
-        tokenizer.pad_token = tokenizer.unk_token
+        tokenizer.eos_token = '</s>'
+        tokenizer.pad_token = tokenizer.eos_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

@@ -153,9 +153,7 @@ def train(args):
                                           optim,
                                           num_warmup_steps=math.ceil(max_steps * 0.03),
                                           num_training_steps=max_steps)
-    strategy_dict = strategy.prepare(
-        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-    )
+    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict['model']
    optim = strategy_dict['optimizer']
    lr_scheduler = strategy_dict['lr_scheduler']