Skip to content

Commit

Permalink
Fix batchsize in llama2 reference implementation
Browse files — browse the repository at this point in the history
  • Loading branch information
arjunsuresh committed Jul 25, 2024
1 parent 7e32454 commit eb19e0c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion script/app-mlperf-inference-mlcommons-python/customize.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,7 +54,7 @@ def preprocess(i):
else:
env['CM_NUM_THREADS'] = env.get('CM_HOST_CPU_TOTAL_CORES', '1')

if env.get('CM_MLPERF_LOADGEN_MAX_BATCHSIZE','') != '' and not env.get('CM_MLPERF_MODEL_SKIP_BATCHING', False):
if env.get('CM_MLPERF_LOADGEN_MAX_BATCHSIZE','') != '' and str(env.get('CM_MLPERF_MODEL_SKIP_BATCHING', False)).lower() not in [ "true", "1", "yes"] :
env['CM_MLPERF_LOADGEN_EXTRA_OPTIONS'] += " --max-batchsize " + str(env['CM_MLPERF_LOADGEN_MAX_BATCHSIZE'])

if env.get('CM_MLPERF_LOADGEN_BATCH_SIZE','') != '':
Expand Down Expand Up @@ -318,6 +318,7 @@ def get_run_cmd_reference(os_info, env, scenario_extra_options, mode_extra_optio
cmd += f" --num-workers {env['CM_MLPERF_INFERENCE_NUM_WORKERS']}"

cmd = cmd.replace("--count", "--total-sample-count")
cmd = cmd.replace("--max-batchsize", "--batch-size")

elif "mixtral-8x7b" in env['CM_MODEL']:
env['RUN_DIR'] = os.path.join(env['CM_MLPERF_INFERENCE_SOURCE'], "language", "mixtral-8x7b")
Expand Down

0 comments on commit eb19e0c

Please sign in to comment.