
Respects torch.device(0)'s new behavior without breaking backward compatibility #509

Merged
merged 2 commits into main from respect_torch_25
Aug 1, 2024

Conversation

@Narsil (Collaborator) commented Jul 31, 2024

What does this PR do?

torch.device(0) means torch.device('cuda:0') in torch<=2.4.
Starting with torch==2.5, this behavior changes and torch.device(0) will mean something else (the exact rules are not entirely clear yet).
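
For a quick illustration of the old behavior (constructing the device object needs no special hardware):

import torch

# torch<=2.4: a bare integer is interpreted as a CUDA ordinal.
d = torch.device(0)
print(d)       # cuda:0
print(d.type)  # cuda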

To respect that behavior, we simply introduce a new anonymous device, which keeps around the information that the user didn't specify a device type.

This is non-breaking in safetensors since we send that information as-is to torch, so each version of torch will treat it according to its own rules.
safetensors doesn't need the device information itself; it only needs to make sure the input is valid.
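
A minimal Python sketch of that pass-through idea (hypothetical names; the actual implementation lives in the Rust bindings):

import torch

def to_torch_device(user_device):
    # Hypothetical sketch: validate the input but forward a bare int
    # untouched, so each torch version applies its own rules to it.
    if isinstance(user_device, int):
        if user_device < 0:
            raise ValueError(f"invalid device index: {user_device}")
        return user_device            # anonymous device: keep the raw index
    return torch.device(user_device)  # named devices resolve as usual

# Needs an accelerator to actually move; on torch<=2.4 index 0 is cuda:0.
if torch.cuda.is_available():
    t = torch.zeros(2).to(to_torch_device(0))
    print(t.device)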


@dvrogozh left a comment

This PR does not fix #499. It fails with:

Traceback (most recent call last):
  File "/home/dvrogozh/examples/meta-llama/Meta-Llama-3-8B-Instruct/run.py", line 23, in <module>
    outputs = pipeline(
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/text_generation.py", line 257, in __call__
    return super().__call__(Chat(text_inputs), **kwargs)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/base.py", line 1254, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/base.py", line 1261, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/base.py", line 1161, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/pipelines/text_generation.py", line 351, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/llama/modeling_llama.py", line 1142, in forward
    outputs = self.model(
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/llama/modeling_llama.py", line 945, in forward
    layer_outputs = decoder_layer(
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/llama/modeling_llama.py", line 675, in forward
    hidden_states = self.input_layernorm(hidden_states)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1735, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dvrogozh/git/pytorch/pytorch/torch/nn/modules/module.py", line 1746, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 164, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/hooks.py", line 335, in pre_forward
    value = self.weights_map[name]
  File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 118, in __getitem__
    return self.dataset[f"{self.prefix}{key}"]
  File "/home/dvrogozh/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 171, in __getitem__
    tensor = f.get_tensor(weight_info.get("weight_name", key))
RuntimeError: Invalid device string: '0'

I did not debug further to find out where device string '0' is actually rejected.

Stack on these versions:

@dvrogozh

I debugged it further. The error comes from here:

tensor = tensor.call_method("to", (device,), Some(&kwargs))?;

And it is raised here on the pytorch side: https://github.com/pytorch/pytorch/blob/ad9826208c4d8e171d0c0d57a14818ede4babad4/c10/core/Device.cpp#L126

To reproduce, I think it's enough to run something like this:

python3 -c 'import torch; print(torch.randn(2, 2).to("0"))'

We can raise it on the pytorch side that .to("0") does not work, but if I understand correctly it currently fails regardless of the backend, i.e. with CUDA it will be the same issue. I am afraid that this PR, as it stands, might introduce a regression for CUDA on all currently available versions of pytorch if merged. And if .to("0") gets implemented on the pytorch side, it won't be available in previous versions. Based on that, I think it makes sense to proceed with the solution I propose in #500, which queries the fully qualified device name from CUDA by index and passes that around. There should be no issue with a full device name, I believe.
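
A rough sketch of that #500 approach as described here (hypothetical helper, not the actual patch):

import torch

def qualified_device(index: int) -> torch.device:
    # Hypothetical: resolve a bare index to a fully qualified device up
    # front, so torch never sees an ambiguous int or a string like "0".
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index}")
    raise RuntimeError(f"no accelerator available to qualify device index {index}")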

@Narsil (Collaborator, Author) commented Aug 1, 2024

Yes, I fixed it.

I'm confused why torch is strict in that specific case, but I fixed it in any case by using a raw int instead of a string.
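
For reference, the distinction the fix relies on (the failing string form matches the traceback above; the int form needs a matching device present to actually move):

import torch

t = torch.randn(2, 2)
# t.to("0")      # RuntimeError: Invalid device string: '0'
if torch.cuda.is_available():
    t = t.to(0)  # accepted: a bare int is treated as a device index
    print(t.device)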

@dvrogozh left a comment

Indeed, .to(N) instead of .to("N") works. This PR addresses #499 for me. I am fine with your implementation over the one in #500.

@Narsil merged commit 74c4e16 into main on Aug 1, 2024
11 checks passed
@Narsil deleted the respect_torch_25 branch on August 1, 2024 at 15:07