Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dataset handling to support channel-specific data #2456

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,14 @@ def trainer_train(
if args.train_type == 'ppo':
trainer_kwargs['reward_model'] = reward_model
trainer_kwargs['value_model'] = value_model
if args.use_channel_loss:
channel_dataset_dict = {}
for sample in val_dataset:
channel = sample['channel']
if channel not in channel_dataset_dict:
channel_dataset_dict[channel] = []
channel_dataset_dict[channel].append(sample)
val_dataset = channel_dataset_dict
trainer = trainer_cls(
model=model,
args=training_args,
Expand Down
6 changes: 6 additions & 0 deletions swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ class SftArguments(ArgumentsBase):
dataset_seed: Optional[int] = None
dataset_test_ratio: float = 0.01
use_loss_scale: bool = False # for agent
use_channel_loss: bool = False
channel_to_save: Optional[str] = None
loss_scale_config_path: str = 'DEFAULT'
system: Optional[str] = None
tools_prompt: Literal['react_en', 'react_zh', 'toolbench'] = 'react_en'
Expand Down Expand Up @@ -1198,6 +1200,10 @@ def _init_training_args(self) -> None:
kwargs['accelerator_config'] = {'dispatch_batches': False}

metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss'
if self.use_channel_loss:
if self.channel_to_save is None:
raise ValueError('Please specify --channel_to_save')
metric_for_best_model = f'{self.channel_to_save}_{metric_for_best_model}'
if hasattr(self, 'rlhf_type') and self.rlhf_type == 'ppo':
metric_for_best_model = None

Expand Down
2 changes: 1 addition & 1 deletion swift/llm/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def new_func(self, *args, **kwargs):

standard_keys = {
'query', 'query_role', 'response', 'rejected_response', 'system', 'history', 'history_roles', 'images', 'objects',
'videos', 'audios', 'tools', 'label'
'videos', 'audios', 'tools', 'label', 'channel'
}


Expand Down
1 change: 1 addition & 0 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
inputs[key] = example.get(key)
if inputs.get('labels') is None:
inputs.pop('loss_scale', None)
inputs['channel'] = example.get('channel', '')
return inputs, tokenizer_kwargs

def _concat_context_list(
Expand Down