diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 0abced612..eb7a681d7 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1325,7 +1325,6 @@ def sample_inputs__scaled_dot_product_efficient_attention( dim_4_q_shape = (batch, num_heads, seq_q, head_dim) dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) - shape_attn_bias = (batch, num_heads, seq_q, seq_kv) qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] @@ -1339,7 +1338,7 @@ def sample_inputs__scaled_dot_product_efficient_attention( make(shape_q), make(shape_kv), make(shape_kv), - attn_bias=make(shape_attn_bias), + attn_bias=None, # TODO: Add attn_bias is_causal=is_causal, dropout_p=dropout_p, compute_log_sumexp=compute_log_sumexp,