Skip to content

Commit

Permalink
Fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
hanbitmyths committed Jun 7, 2024
1 parent df91085 commit bc1a595
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def get_model(args: argparse.Namespace):
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
except Exception as e:
# When flash_attention or sdpa doesn't support a model, it throws an exception.
# Rather than stopping a process, run as eager mode.
print("Try to load a model using eager mode: ", e)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
Expand Down Expand Up @@ -115,18 +117,16 @@ def get_model(args: argparse.Namespace):
return model


# When it runs a model without position_ids input, ORT keep printint out a complaint.
# Check if a model has a position_ids input and suppress if not.
def has_position_ids(args):
if args.benchmark_type != "ort":
return True

import onnx
import sys

model = onnx.load(args.onnx_model_path, load_external_data=False)
for input in model.graph.input:
if input.name == "position_ids":
return True
return False
return any(input.name == "position_ids" for input in model.graph.input)


def run_inference(args, model, runs, inputs, outputs):
Expand Down Expand Up @@ -356,6 +356,8 @@ def get_args():
setattr(args, "engine", engine) # noqa: B010
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010

args.use_buffer_share = args.use_buffer_share and engine == "ort"

return args


Expand Down Expand Up @@ -607,7 +609,7 @@ def main():
)
all_csv_metrics.append(csv_metrics)

except Exception as e: # noqa: E722
except Exception as e:
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")

filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,15 +466,15 @@ def get_initial_inputs_and_outputs(
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if engine == "ort" and use_buffer_share else 0,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if engine == "ort" and use_buffer_share else 0,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
Expand Down

0 comments on commit bc1a595

Please sign in to comment.