From fc2e8f67fcd6902ac528c0bc4424e4ceb0514541 Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Thu, 16 May 2024 16:47:36 +0200
Subject: [PATCH] feat(bench): add 'num_inference_steps' arg

This commit gives users the ability to set the `num_inference_steps`
pipeline parameter when running the benchmarking script.
---
 runner/bench.py | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/runner/bench.py b/runner/bench.py
index c286a9a0..cfc6d319 100644
--- a/runner/bench.py
+++ b/runner/bench.py
@@ -24,30 +24,32 @@ class BenchMetrics(BaseModel):
     max_mem_reserved: float
 
 
-def call_pipeline(pipeline: Pipeline, batch_size=1) -> List[any]:
+def call_pipeline(pipeline: Pipeline, batch_size=1, **kwargs) -> List[any]:
     if isinstance(pipeline, TextToImagePipeline):
         prompts = [PROMPT] * batch_size
-        return pipeline(prompts)
+        return pipeline(prompts, **kwargs)
     elif isinstance(pipeline, ImageToImagePipeline):
         prompts = [PROMPT] * batch_size
         images = [Image.open(IMAGE).convert("RGB")] * batch_size
-        return pipeline(prompts, images)
+        return pipeline(prompts, images, **kwargs)
     elif isinstance(pipeline, ImageToVideoPipeline):
         images = [Image.open(IMAGE).convert("RGB")] * batch_size
-        return pipeline(images)
+        return pipeline(images, **kwargs)
     else:
         raise Exception("invalid pipeline")
 
 
-def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
+def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1, num_inference_steps=None) -> BenchMetrics:
     inference_time = np.zeros(runs)
     inference_time_per_output = np.zeros(runs)
     max_mem_allocated = np.zeros(runs)
     max_mem_reserved = np.zeros(runs)
 
+    kwargs = {"num_inference_steps": num_inference_steps} if num_inference_steps is not None else {}
+
     for i in range(runs):
         start = time()
-        output = call_pipeline(pipeline, batch_size)
+        output = call_pipeline(pipeline, batch_size, **kwargs)
         if isinstance(output, tuple):
             output = output[0]
         assert len(output) == batch_size
@@ -94,13 +96,20 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
         required=False,
         help="the number of times to call the pipeline",
     )
+    parser.add_argument(
+        "--num_inference_steps",
+        type=int,
+        default=None,
+        required=False,
+        help="the number of inference steps to run for the pipeline",
+    )
     parser.add_argument(
         "--batch_size", type=int, default=1, required=False, help="the size of a batch"
     )
 
     args = parser.parse_args()
 
-    print(f"{args.pipeline=} {args.model_id=} {args.runs=} {args.batch_size=}")
+    print(f"{args.pipeline=} {args.model_id=} {args.runs=} {args.batch_size=} {args.num_inference_steps=}")
 
     start = time()
     pipeline = load_pipeline(args.pipeline, args.model_id)
@@ -113,10 +122,10 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
     # Collect pipeline warmup metrics if stable-fast is enabled
     if os.getenv("SFAST", "").strip().lower() == "true":
         warmups = 3
-        warmup_metrics = bench_pipeline(pipeline, args.batch_size, warmups)
+        warmup_metrics = bench_pipeline(pipeline, args.batch_size, warmups, args.num_inference_steps)
 
     # Collect pipeline inference metrics
-    metrics = bench_pipeline(pipeline, args.batch_size, args.runs)
+    metrics = bench_pipeline(pipeline, args.batch_size, args.runs, args.num_inference_steps)
 
     print("\n")
     print("----AGGREGATE METRICS----")
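
Note (illustration only, not part of the patch): the standalone sketch below shows the conditional-kwargs forwarding pattern the patch relies on. A kwargs dict is built only when num_inference_steps is explicitly set, so the pipeline's own default step count applies otherwise. DummyPipeline, its prompt strings, and the numbers used here are hypothetical stand-ins for the real pipeline classes and arguments in runner/bench.py.

from time import time
from typing import Any, List, Optional


class DummyPipeline:
    """Hypothetical stand-in for TextToImagePipeline and friends."""

    def __call__(self, prompts: List[str], num_inference_steps: int = 50) -> List[str]:
        # Pretend to generate one output per prompt, honoring the step count.
        return [f"{p} ({num_inference_steps} steps)" for p in prompts]


def call_pipeline(pipeline: DummyPipeline, batch_size: int = 1, **kwargs: Any) -> List[str]:
    # Extra keyword arguments are forwarded untouched to the pipeline call.
    return pipeline(["a prompt"] * batch_size, **kwargs)


def bench_pipeline(
    pipeline: DummyPipeline,
    batch_size: int = 1,
    runs: int = 1,
    num_inference_steps: Optional[int] = None,
) -> float:
    # Forward the parameter only when it was explicitly set; otherwise the
    # pipeline's own default applies.
    kwargs = (
        {"num_inference_steps": num_inference_steps}
        if num_inference_steps is not None
        else {}
    )
    start = time()
    for _ in range(runs):
        output = call_pipeline(pipeline, batch_size, **kwargs)
        assert len(output) == batch_size
    return time() - start


if __name__ == "__main__":
    # Roughly analogous to passing --num_inference_steps 25 on the command line.
    print(f"total: {bench_pipeline(DummyPipeline(), batch_size=2, runs=3, num_inference_steps=25):.6f}s")

Building the kwargs dict conditionally, rather than always passing num_inference_steps, keeps the unset case identical to the previous benchmarking behavior.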