Skip to content

Commit

Permalink
Merge branch 'main' into wpo
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif authored Oct 1, 2024
2 parents 84269e0 + 0a566f0 commit e3f9a75
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions examples/scripts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,13 @@ def chat_cli():

chat.append({"role": "user", "content": user_input})

inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
)
attention_mask = torch.ones_like(inputs)
generation_kwargs = dict(
inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
),
inputs=inputs,
attention_mask=attention_mask,
streamer=generation_streamer,
max_new_tokens=current_args.max_new_tokens,
do_sample=current_args.do_sample,
Expand Down
2 changes: 1 addition & 1 deletion trl/commands/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def chat():
init_zero_verbose()
trl_examples_dir = os.path.dirname(__file__)

command = f"accelerate launch {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}"
command = f"python {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}"

try:
subprocess.run(
Expand Down

0 comments on commit e3f9a75

Please sign in to comment.