Commit 2047bdd ("fix")
goliaro committed Nov 6, 2024 (parent: fc884fe)
Showing 3 changed files with 39 additions and 33 deletions.
tests/fine_grained_alignment_test.sh (1 addition, 1 deletion)

@@ -11,7 +11,7 @@ CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"}
 NUM_STEPS=${NUM_STEPS:-2}
 
 cleanup() {
-    rm -rf "${CACHE_PATH}"/debug ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt
+    eval rm -rf "${CACHE_PATH}/debug" ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt
 }
 
 # Cd into directory holding this script
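Note: a likely motivation for the `eval` (my reading, not stated in the commit) is that the `CACHE_PATH` default `~/.cache/flexflow` sits inside double quotes, so the tilde is never expanded and the old `rm -rf "${CACHE_PATH}"/debug` targeted a directory literally named `~`. Re-evaluating the command lets the shell expand the tilde. A minimal Python sketch of the literal-tilde pitfall:

```python
import os

# CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"} keeps the tilde literal,
# because defaults inside double quotes are not tilde-expanded. Downstream
# commands then receive the raw string rather than an absolute path.
cache_path = "~/.cache/flexflow"
print(os.path.isdir(cache_path))        # False: no directory literally named "~"
print(os.path.expanduser(cache_path))   # e.g. /home/user/.cache/flexflow
```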
tests/inference/inference_alignment_test.py (36 additions, 30 deletions)

@@ -17,7 +17,7 @@ def check_bwd_pass(self):
     def check_step(self, step_idx, learning_rate=0.001):
         raise NotImplementedError()
 
-class LllamaAlignmentTest(AlignmentTest):
+class LlamaAlignmentTest(AlignmentTest):
     def __init__(self, hf_config, tp_degree=1):
         self.hf_config = hf_config
         self.num_layers = self.hf_config.num_hidden_layers
@@ -168,7 +168,10 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp
         ff_tensor = np.loadtxt(ff_tensor_path, delimiter=',')
         self.ff_batch_size = ff_tensor.shape[0]
 
-        ff_shape = replace_value(ff_shape, self.num_tokens, self.ff_batch_size)
+        if "lm_head" in ff_tensor_path:
+            ff_shape = replace_value(ff_shape, 1, self.ff_batch_size)
+        else:
+            ff_shape = replace_value(ff_shape, self.num_tokens, self.ff_batch_size)
         ff_tensors = [load_ff_tensor(ff_tensor_path.replace("shard_0", f"shard_{tp_idx}"), ff_shape) for tp_idx in range(self.tp_degree)]
         if self.tp_degree > 1:
             # if replicate, check that they are identical
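Note: the new branch presumably reflects that FlexFlow's saved `lm_head` tensors hold a single decoding position, so the axis to rescale to the runtime batch size is the size-1 one rather than `num_tokens`. A self-contained sketch of the shape rewrite, with a simplified stand-in for `replace_value` and invented dimension values:

```python
def replace_value(shape, old, new):
    # simplified stand-in for the helper in tests/peft/alignment/align_test_utils.py:
    # swap the first occurrence of `old` for `new`
    idx = shape.index(old)
    return shape[:idx] + [new] + shape[idx + 1:]

num_tokens, ff_batch_size = 24, 8                 # hypothetical values
lm_head_shape = [32000, 1, 1]                     # hypothetical: vocab x 1 position
other_shape = [4096, num_tokens, 1]               # hypothetical activation shape
print(replace_value(lm_head_shape, 1, ff_batch_size))           # [32000, 8, 1]
print(replace_value(other_shape, num_tokens, ff_batch_size))    # [4096, 8, 1]
```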
@@ -356,11 +359,14 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
         ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
         input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0)
         hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
-        ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)
+        ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)[:,:,-1].squeeze()
+        hf_tensor = hf_tensor.squeeze()
+        print(hf_tensor.shape, ff_tensor.shape)
         compare(hf_tensor, ff_tensor, label="LM head input")
         output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
         hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
-        ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+        ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)[:,:,-1].squeeze()
+        hf_tensor = hf_tensor.squeeze()
         compare(hf_tensor, ff_tensor, label="LM head output")
 
 class OPTAlignmentTest(AlignmentTest):
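Note: the `[:,:,-1].squeeze()` added to both LM-head comparisons above appears to select the final decoding position from the FlexFlow tensor before comparing it against the squeezed Hugging Face tensor. A NumPy sketch of that alignment step; the axis layout and sizes are assumptions, not taken from the code:

```python
import numpy as np

vocab, batch, steps = 32000, 8, 3
ff_tensor = np.random.rand(vocab, batch, steps)   # assume FF keeps one slice per decoding step
last = ff_tensor[:, :, -1].squeeze()              # keep only the final step, drop size-1 axes
print(last.shape)                                 # (32000, 8)
```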
@@ -664,17 +670,17 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             assert torch.allclose(ff_qkv_tensor_out, ff_attn_tensor_in)
 
             # Compared scaled qproj
-            hf_tensor_name = f"layers.{i}.self_attn.scaled_qproj"
-            input_c = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
-            output_c = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
-            scaled_qproj_in = get_hf_tensor(hf_tensor_name, input_c)
-            scaled_qproj_out = get_hf_tensor(hf_tensor_name, output_c)
-            assert torch.allclose(scaled_qproj_in, scaled_qproj_out)
-            ff_tensor_name = f"layers.{i}.layers.{i}.self_attn.scaled_qkv_proj"
-            scaled_qkv_proj0 = load_ff_tensor(os.path.join(ff_fwd_folder, f"{ff_tensor_name}.output_0"), [64*6,3,9])
-            scaled_qkv_proj1 = load_ff_tensor(os.path.join(ff_fwd_folder, f"{ff_tensor_name}.output_0").replace("shard_0", "shard_1"), [64*6,3,9])
-            ff_scaled_qkv_proj = np.concatenate([scaled_qkv_proj0, scaled_qkv_proj1], axis=0)
-            ff_scaled_q_proj = torch.from_numpy(ff_scaled_qkv_proj[:, :1, :]).to(scaled_qproj_out.dtype)
+            # hf_tensor_name = f"layers.{i}.self_attn.scaled_qproj"
+            # input_c = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
+            # output_c = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
+            # scaled_qproj_in = get_hf_tensor(hf_tensor_name, input_c)
+            # scaled_qproj_out = get_hf_tensor(hf_tensor_name, output_c)
+            # assert torch.allclose(scaled_qproj_in, scaled_qproj_out)
+            # ff_tensor_name = f"layers.{i}.layers.{i}.self_attn.scaled_qkv_proj"
+            # scaled_qkv_proj0 = load_ff_tensor(os.path.join(ff_fwd_folder, f"{ff_tensor_name}.output_0"), [64*6,3,9])
+            # scaled_qkv_proj1 = load_ff_tensor(os.path.join(ff_fwd_folder, f"{ff_tensor_name}.output_0").replace("shard_0", "shard_1"), [64*6,3,9])
+            # ff_scaled_qkv_proj = np.concatenate([scaled_qkv_proj0, scaled_qkv_proj1], axis=0)
+            # ff_scaled_q_proj = torch.from_numpy(ff_scaled_qkv_proj[:, :1, :]).to(scaled_qproj_out.dtype)
             # print("HF scaled qproj:")
             # print(scaled_qproj_out.squeeze().T)
             # print("FF scaled q proj:")
@@ -688,15 +694,15 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
 
 
             # check that out_proj input, attn_scores out and input are identical on the hf side
-            hf_tensor_name = f"layers.{i}.self_attn.attn_scores"
-            input_c = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
-            output_c = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
-            attn_scores_in = get_hf_tensor(hf_tensor_name, input_c)
-            attn_scores_out = get_hf_tensor(hf_tensor_name, output_c)
+            # hf_tensor_name = f"layers.{i}.self_attn.attn_scores"
+            # input_c = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
+            # output_c = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
+            # attn_scores_in = get_hf_tensor(hf_tensor_name, input_c)
+            # attn_scores_out = get_hf_tensor(hf_tensor_name, output_c)
             hf_tensor_name = f"layers.{i}.self_attn.out_proj"
-            out_proj_in = get_hf_tensor(hf_tensor_name, input_c)
-            assert torch.allclose(attn_scores_in, attn_scores_out)
-            assert torch.allclose(attn_scores_in, out_proj_in)
+            # out_proj_in = get_hf_tensor(hf_tensor_name, input_c)
+            # assert torch.allclose(attn_scores_in, attn_scores_out)
+            # assert torch.allclose(attn_scores_in, out_proj_in)
 
             # Compare out proj input. This should be the output of the attention without any bias involved
             hf_tensor_name = f"layers.{i}.self_attn.out_proj"
@@ -707,12 +713,12 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name)
compare(hf_tensor, ff_tensor, label=f"Attention o-proj {i} input")

hf_tensor_name = f"layers.{i}.self_attn.attn_scores"
ff_tensor_name = f"layers.{i}.layers.{i}.self_attn"
output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
compare(hf_tensor, ff_tensor, label=f"Attention {i} output")
# hf_tensor_name = f"layers.{i}.self_attn.attn_scores"
# ff_tensor_name = f"layers.{i}.layers.{i}.self_attn"
# output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
# hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
# ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
# compare(hf_tensor, ff_tensor, label=f"Attention {i} output")

# hf_tensor_name = f"layers.{i}.final_layer_norm"
# ff_tensor_name = f"layers.{i}.layers.{i}.add_bias_residual_layer_norm"
@@ -808,7 +814,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
 hf_config = AutoConfig.from_pretrained(args.model_name)
 alignment_class = None
 if hf_config.architectures[0] == "LlamaForCausalLM":
-    alignment_class = LllamaAlignmentTest(hf_config, tp_degree=args.tensor_parallelism_degree)
+    alignment_class = LlamaAlignmentTest(hf_config, tp_degree=args.tensor_parallelism_degree)
 elif hf_config.architectures[0] == "OPTForCausalLM":
     alignment_class = OPTAlignmentTest(hf_config, tp_degree=args.tensor_parallelism_degree)
 
tests/peft/alignment/align_test_utils.py (2 additions, 2 deletions)

@@ -472,9 +472,9 @@ class TensorComparisonIdxs:
 def replace_value(lst, old_value, new_value):
     occurrences = lst.count(old_value)
     if occurrences == 0:
-        raise ValueError(f"Value {old_value} not found in the list.")
+        raise ValueError(f"Value {old_value} not found in the list: {lst}")
     elif occurrences > 1:
-        warnings.warn(f"Multiple instances of {old_value} found in the list.")
+        warnings.warn(f"Multiple instances of {old_value} found in the list: {lst}")
     occurrence_idx=0
     for i, value in enumerate(lst):
         if value == old_value:
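Note: appending `{lst}` to both messages makes shape-mismatch failures self-describing, since the exception now shows which shape was being rewritten. A quick sketch of the failure mode, with invented values:

```python
# assuming the repo's tests/peft/alignment directory is on sys.path
from align_test_utils import replace_value

ff_shape = [1, 1024, 32000]         # hypothetical lm_head shape
try:
    replace_value(ff_shape, 24, 8)  # 24 (num_tokens) is absent from an lm_head shape
except ValueError as e:
    print(e)                        # Value 24 not found in the list: [1, 1024, 32000]
```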
