Commit fc2e8f6

feat(bench): add 'num_inference_steps' arg
This commit gives users the ability to set the `num_inference_steps`
pipeline parameter when running the benchmarking script.
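A hypothetical invocation of the updated script (the pipeline name, model ID, and values below are illustrative, not taken from this commit):

    python bench.py \
        --pipeline text-to-image \
        --model_id stabilityai/sd-turbo \
        --runs 3 \
        --batch_size 1 \
        --num_inference_steps 2
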
rickstaa committed May 16, 2024
1 parent 272ac74 commit fc2e8f6
1 changed file: runner/bench.py (18 additions, 9 deletions)
@@ -24,30 +24,32 @@ class BenchMetrics(BaseModel):
     max_mem_reserved: float


-def call_pipeline(pipeline: Pipeline, batch_size=1) -> List[any]:
+def call_pipeline(pipeline: Pipeline, batch_size=1, **kwargs) -> List[any]:
     if isinstance(pipeline, TextToImagePipeline):
         prompts = [PROMPT] * batch_size
-        return pipeline(prompts)
+        return pipeline(prompts, **kwargs)
     elif isinstance(pipeline, ImageToImagePipeline):
         prompts = [PROMPT] * batch_size
         images = [Image.open(IMAGE).convert("RGB")] * batch_size
-        return pipeline(prompts, images)
+        return pipeline(prompts, images, **kwargs)
     elif isinstance(pipeline, ImageToVideoPipeline):
         images = [Image.open(IMAGE).convert("RGB")] * batch_size
-        return pipeline(images)
+        return pipeline(images, **kwargs)
     else:
         raise Exception("invalid pipeline")


-def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
+def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1, num_inference_steps=None) -> BenchMetrics:
     inference_time = np.zeros(runs)
     inference_time_per_output = np.zeros(runs)
     max_mem_allocated = np.zeros(runs)
     max_mem_reserved = np.zeros(runs)

+    kwargs = {"num_inference_steps": num_inference_steps} if num_inference_steps is not None else {}
+
     for i in range(runs):
         start = time()
-        output = call_pipeline(pipeline, batch_size)
+        output = call_pipeline(pipeline, batch_size, **kwargs)
         if isinstance(output, tuple):
             output = output[0]
         assert len(output) == batch_size
@@ -94,13 +96,20 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
         required=False,
         help="the number of times to call the pipeline",
     )
+    parser.add_argument(
+        "--num_inference_steps",
+        type=int,
+        default=None,
+        required=False,
+        help="the number of inference steps to run for the pipeline",
+    )
     parser.add_argument(
         "--batch_size", type=int, default=1, required=False, help="the size of a batch"
     )

     args = parser.parse_args()

-    print(f"{args.pipeline=} {args.model_id=} {args.runs=} {args.batch_size=}")
+    print(f"{args.pipeline=} {args.model_id=} {args.runs=} {args.batch_size=} {args.num_inference_steps=}")

     start = time()
     pipeline = load_pipeline(args.pipeline, args.model_id)
@@ -113,10 +122,10 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
     # Collect pipeline warmup metrics if stable-fast is enabled
     if os.getenv("SFAST", "").strip().lower() == "true":
         warmups = 3
-        warmup_metrics = bench_pipeline(pipeline, args.batch_size, warmups)
+        warmup_metrics = bench_pipeline(pipeline, args.batch_size, warmups, args.num_inference_steps)

     # Collect pipeline inference metrics
-    metrics = bench_pipeline(pipeline, args.batch_size, args.runs)
+    metrics = bench_pipeline(pipeline, args.batch_size, args.runs, args.num_inference_steps)

     print("\n")
     print("----AGGREGATE METRICS----")
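
Note the conditional kwargs dict in bench_pipeline: num_inference_steps is only forwarded to the pipeline when the flag is actually set, so each pipeline keeps its own default step count otherwise; passing num_inference_steps=None through explicitly could override that default. A minimal sketch of the pattern (the toy pipeline below is illustrative, not part of this commit):

    def toy_pipeline(prompts, num_inference_steps=50):
        # Falls back to its own default when the caller omits the argument.
        return [f"{p}: {num_inference_steps} steps" for p in prompts]

    num_inference_steps = None  # e.g. --num_inference_steps was not passed
    kwargs = (
        {"num_inference_steps": num_inference_steps}
        if num_inference_steps is not None
        else {}
    )
    print(toy_pipeline(["a cat"], **kwargs))  # ['a cat: 50 steps']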
