Skip to content

Commit

Permalink
align all peft
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Aug 22, 2024
1 parent 6ebea46 commit f98999c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tests/peft/peft_alignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def compare(hf_tensor, ff_tensor, label="", tolerance=1e-4):
hf_finetuned_weight = get_hf_tensor(hf_finetuned_weight_name)
torch.testing.assert_close(hf_gradient, (hf_original_weight-hf_finetuned_weight)/learning_rate, rtol=1.3e-6, atol=1e-5)
ff_gradient_name = convert_hf_filename_to_ff(hf_gradient_name)
ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.TO_REDUCE)
ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.REPLICATE)
compare(hf_gradient, ff_gradient, label=f"LoRA_B {i} gradient")
# ff_out_gradient_name = f"layers.{i}.layers.{i}.mlp.down_proj.lora.output_gradient_0"
# ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0")
Expand Down Expand Up @@ -708,7 +708,7 @@ def compare(hf_tensor, ff_tensor, label="", tolerance=1e-4):
hf_finetuned_weight = get_hf_tensor(hf_finetuned_weight_name)
torch.testing.assert_close(hf_gradient, (hf_original_weight-hf_finetuned_weight)/learning_rate, rtol=1.3e-6, atol=1e-5)
ff_gradient_name = convert_hf_filename_to_ff(hf_gradient_name)
ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.TO_REDUCE)
ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.PARTITION)
compare(hf_gradient, ff_gradient, label=f"LoRA_A {i} gradient")

parser = argparse.ArgumentParser(description='Argument Parser Example')
Expand Down
2 changes: 1 addition & 1 deletion tests/peft_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ python ./tests/peft/hf_finetune.py --peft-model-id goliaro/llama-160m-lora --sav
# Python test
python ./inference/python/ff_peft.py
# Check alignment
python ./tests/peft/peft_alignment_test.py
python ./tests/peft/peft_alignment_test.py -tp 2

# C++ test
./build/inference/peft/peft \
Expand Down

0 comments on commit f98999c

Please sign in to comment.