diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index b2b8a4a6776d9..79a05e4e3c1b3 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -615,9 +615,13 @@ def initialize_dummy_weights(
                 # XLA device does not support torch.Generator()
                 param.uniform_(low, high)
                 continue
 
-            generator = torch.Generator(device=param.data.device)
-            generator.manual_seed(seed)
+            if current_platform.is_hpu():
+                # HPU has no torch.Generator support; seed the default
+                # HPU generator instead.
+                import habana_frameworks.torch.hpu.random as htrandom
+                generator = htrandom.default_generators[0].manual_seed(seed)
+            else:
+                generator = torch.Generator(device=param.data.device)
+                generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (FP8)
                 dtype = param.data.dtype