Skip to content

Commit

Permalink
support session-based training
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuanchen Xu committed Jul 28, 2023
1 parent ef4b99e commit 649aca1
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 15 deletions.
87 changes: 87 additions & 0 deletions applications/Chat/coati/dataset/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
sep: str = "</s>"

skip_next: bool = False

def get_prompt(self):
if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
ret = self.system
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ": "
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")

def append_message(self, role, message):
self.messages.append([role, message])

def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret

def copy(self):
return Conversation(system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep)

def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep
}


conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_EOS_TOKEN,
sep="</s>",
)

default_conversation = conv
97 changes: 87 additions & 10 deletions applications/Chat/coati/dataset/sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
from typing import Callable, Dict, Sequence, Tuple

import torch
import torch.distributed as dist
Expand All @@ -25,11 +25,21 @@

from colossalai.logging import get_dist_logger

from .conversation import default_conversation
from .utils import is_rank_0, jload

# The following is a template prompt for a 4-round conversation.
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""
# Please note that we only calculate loss on assistant's answer tokens.

logger = get_dist_logger()

IGNORE_INDEX = -100
DEFAULT_EOS_TOKEN = "</s>"
PROMPT_DICT = {
"prompt_input":
("Below is an instruction that describes a task, paired with an input that provides further context. "
Expand Down Expand Up @@ -107,6 +117,61 @@ def preprocess(
return dict(input_ids=input_ids, labels=labels)


def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
max_length: int) -> Dict:
"""Preprocess the conversation data by tokenizing."""
conversations = []
intermediates = []
for source in sources:
header = f"{default_conversation.system}"
conversation, intermediate = _add_speaker_and_signal(header, source)
conversations.append(conversation)
intermediates.append(intermediate)

conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)

assert len(targets) == len(intermediates)
for target, inters in zip(targets, intermediates):
mask = torch.zeros_like(target, dtype=torch.bool)
for inter in inters:
tokenized = _tokenize_fn(inter, tokenizer, max_length)

start_idx = tokenized["input_ids"][0].size(0) - 1
end_idx = tokenized["input_ids"][1].size(0)

mask[start_idx:end_idx] = True
target[~mask] = IGNORE_INDEX

return dict(input_ids=input_ids, labels=targets)


def _add_speaker_and_signal(header: str,
source: List[Dict],
get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
END_SIGNAL = DEFAULT_EOS_TOKEN
conversation = header
intermediate = []
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = default_conversation.roles[1]
else:
from_str = 'unknown'

value = from_str + ": " + sentence["value"] + END_SIGNAL
if sentence["from"].lower() == "gpt":
start = conversation + from_str + ": "
end = conversation + value
intermediate.append([start, end])
if get_conversation:
conversation += value
return conversation, intermediate


class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

Expand All @@ -125,15 +190,27 @@ def __init__(self,
list_data_dict = list_data_dict[:max_datasets_size]

logger.info("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

logger.info("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer, max_length)
if "conversations" not in list_data_dict[0]:
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example)
if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

if is_rank_0():
logger.info("Tokenizing inputs... This may take some time...")

data_dict = preprocess(sources, targets, tokenizer, max_length)
else:
if is_rank_0():
logger.info("Tokenizing inputs... This may take some time...")

sources = [conv["conversations"] for conv in list_data_dict]
data_dict = preprocess_conversation(sources, tokenizer, max_length)

if is_rank_0():
logger.info("Tokenizing finish.")

self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
Expand Down
44 changes: 44 additions & 0 deletions applications/Chat/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
- [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
Expand Down Expand Up @@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>

### Conversation dataset generation

In order to further improve the model's ability to handle multi-turn conversations, we need to include samples with multi-turn conversations in the dataset. However, the samples in InstructWild and Alpaca datasets currently consist of only single-turn conversations, and their dataset organization is not suitable for storing multi-turn conversations. Additionally, after converting the aforementioned datasets, we also need to include multi-turn conversation datasets like ShareGPT, and we should transform them into the training format supported by ColossalChat.

A sample of conversation dataset should have the following fields:

* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (str, compulsory): Conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.

A simple example:
```json
{
"type": "instruction",
"language": "English",
"dataset": "Alpaca",
"conversations": [
{
"from": "human",
"value": "Give three tips for staying healthy."
},
{
"from": "gpt",
"value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
}
],
"id": 1
}
```

> **NOTE:** Only key `conversations` is compulsary for training and other keys serve as metadata. The length of `conversations` varies.
You can run the `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat.

You can use the following cmd to generate conversation dataset.
```
python generate_conversation_dataset.py \
--dataset "All"
--save_path "/path/to/dataset"
```

## Stage1 - Supervised instructs tuning

Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
Expand Down
79 changes: 79 additions & 0 deletions applications/Chat/examples/generate_conversation_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
# We can convert dataset with the same format("instruction", "input", "output") as Alpaca into a one-round conversation.
conversation_dataset = []
dataset = load_dataset("tatsu-lab/alpaca", split="train")

instructions = dataset["instruction"]
inputs = dataset["input"]
outputs = dataset["output"]

assert len(instructions) == len(inputs) == len(outputs)

for idx in range(len(instructions)):
human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
human = {"from": "human", "value": human_utterance}

gpt_utterance = outputs[idx]
gpt = {"from": "gpt", "value": gpt_utterance}

conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
conversation_dataset.append(conversation)

return conversation_dataset


def generate_sharegpt():
# ShareGPT data requires less processing.
conversation_dataset = []
dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train")

conversations = dataset["conversations"]

for idx in range(len(conversations)):
for conv in conversations[idx]:
# We don't need markdown and text value.
del conv["markdown"]
del conv["text"]

conversation = dict(type="conversation",
language="Multilingual",
dataset="ShareGPT",
conversations=conversations[idx])
conversation_dataset.append(conversation)

return conversation_dataset


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
type=str,
default="All",
choices=["Alpaca", "ShareGPT", "All"],
help="which dataset to convert, All will combine Alpaca and ShareGPT")
parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args()

conversation_dataset = []

if args.dataset == "Alpaca":
conversation_dataset.extend(generate_alpaca())
elif args.dataset == "ShareGPT":
conversation_dataset.extend(generate_sharegpt())
else:
conversation_dataset.extend(generate_alpaca())
conversation_dataset.extend(generate_sharegpt())

for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1

with open(args.save_path, mode='w') as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
8 changes: 3 additions & 5 deletions applications/Chat/examples/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def train(args):
padding_side="right",
use_fast=False,
)
tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token
tokenizer.eos_token = '</s>'
tokenizer.pad_token = tokenizer.eos_token
else:
raise ValueError(f'Unsupported model "{args.model}"')

Expand Down Expand Up @@ -153,9 +153,7 @@ def train(args):
optim,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
strategy_dict = strategy.prepare(
dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
Expand Down

0 comments on commit 649aca1

Please sign in to comment.