diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index fd6a15a5cfe..e9d0c85cdef 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -344,7 +344,8 @@ def should_initialize_on_xla(self): # 1. Models don't expect 'tpu' as their device. # 2. 'moco' initializes a ProcessGroup, i.e. the backend depends on # the given device - return self.is_accelerator_tpu() or self.model_name == "moco" + return self.is_accelerator_tpu() or (self.model_name == "moco" and + self.benchmark_experiment.xla) def is_inference(self): return self.benchmark_experiment.test == "eval"