Respects torch.device(0) new behavior without breaking backward compatibility #509
Conversation
This PR does not fix #499. Fails with:
Traceback (most recent call last):
File "/home/dvrogozh/examples/meta-llama/Meta-Llama-3-8B-Instruct/run.py", line 23, in <module>
outputs = pipeline(
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/text_generation.py", line 257, in __call__
return super().__call__(Chat(text_inputs), **kwargs)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/base.py", line 1254, in __call__
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/base.py", line 1261, in run_single
model_outputs = self.forward(model_inputs, **forward_params)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/base.py", line 1161, in forward
model_outputs = self._forward(model_inputs, **forward_params)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/text_generation.py", line 351, in _forward
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
File "/home/dvrogozh/git/pytorch/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/generation/utils.py", line 1989, in generate
result = self._sample(
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/generation/utils.py", line 2932, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/llama/modeling_llama.py", line 1142, in forward
outputs = self.model(
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/llama/modeling_llama.py", line 945, in forward
layer_outputs = decoder_layer(
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/llama/modeling_llama.py", line 675, in forward
hidden_states = self.input_layernorm(hidden_states)
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 164, in new_forward
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 335, in pre_forward
value = self.weights_map[name]
File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 118, in __getitem__
return self.dataset[f"{self.prefix}{key}"]
File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 171, in __getitem__
tensor = f.get_tensor(weight_info.get("weight_name", key))
RuntimeError: Invalid device string: '0'
I did not debug further to find out where the device string '0' is actually rejected.
Stack on these versions:
I debugged it further. The error comes from here: safetensors/bindings/python/src/lib.rs, line 635 at 8d21261
And it is raised here on the pytorch side: https://github.com/pytorch/pytorch/blob/ad9826208c4d8e171d0c0d57a14818ede4babad4/c10/core/Device.cpp#L126. To reproduce, I think it's enough to run something like this:
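The snippet itself isn't preserved in this thread; the following is only a guess at what a minimal reproduction might look like, assuming the bare int device index is forwarded to safe_open (the file name and tensor name are illustrative):

```python
# Hypothetical reproduction sketch (the original snippet is not shown above).
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a tiny checkpoint to disk.
save_file({"weight": torch.zeros(2, 2)}, "model.safetensors")

# Pass a bare int as the device: if the binding turns 0 into the string "0"
# before calling torch.device(), get_tensor() fails with
# "RuntimeError: Invalid device string: '0'".
with safe_open("model.safetensors", framework="pt", device=0) as f:
    tensor = f.get_tensor("weight")
```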
We can raise this on the pytorch side as well.
Yes, I fixed it. I'm confused why torch is strict in that specific case, but I fixed it in any case by using a raw int instead of a string.
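For reference, the distinction being discussed can be seen directly in torch (a small illustration, not part of the PR):

```python
import torch

# An integer index is a valid device argument; what it means is
# version-dependent (cuda:0 on torch<=2.4, something else on torch>=2.5).
dev = torch.device(0)
print(dev)

# A bare numeric string is not a valid device string and raises the
# error seen in the traceback above.
try:
    torch.device("0")
except RuntimeError as err:
    print(err)  # Invalid device string: '0'
```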
What does this PR do?
torch.device(0) means torch.device('cuda:0') in torch<=2.4. Starting with torch==2.5, this behavior changes and means something else (the exact rules are not entirely clear).
To respect that behavior, we introduce a new anonymous device which simply keeps around the information that the user didn't specify any device. This is non-breaking in safetensors since we send that information as-is to torch, so each version of torch will treat it according to its own rules. safetensors doesn't need the device information; it just needs to make sure that the input is valid.
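A rough Python-level sketch of that "send it as-is to torch" idea, purely illustrative (the real change lives in the Rust binding, and resolve_device is a made-up helper):

```python
import torch

def resolve_device(device):
    """Made-up helper mirroring the 'forward the input as-is' idea."""
    # torch.device() both validates the input and applies the rules of the
    # installed torch version (an int meant cuda:0 on torch<=2.4 and means
    # something else on torch>=2.5), so safetensors itself never has to
    # interpret what the index stands for.
    return torch.device(device)

resolve_device(0)          # interpretation depends on the torch version
resolve_device("cuda:0")   # explicit and stable across versions
```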