Query device name from pytorch if only device index is given #500

Closed
wants to merge 2 commits into from

Commits on Jul 31, 2024

  1. Simplify code in lib.rs for Device

    Co-authored-by: Dmitry Rogozhkin <[email protected]>
    Narsil and dvrogozh committed Jul 31, 2024
    f141963
  2. Query device name from pytorch if only device index is given

    Fixes: huggingface#499
    Fixes: huggingface/transformers#31941
    
    In some cases only a device index is given when querying the device.
    In this case both PyTorch and Safetensors returned 'cuda:N' by default.
    This causes runtime failures if the user actually runs on a non-CUDA
    device and does not have CUDA at all. This was recently addressed on
    the PyTorch side by [1]: starting from PyTorch 2.5, calling
    'torch.device(N)' returns the current device instead of the cuda device.
    
    This commit makes a similar change to Safetensors. If only a device
    index is given, Safetensors queries the device by calling
    'torch.device(N)' and returns the result. This change is backward
    compatible, since the call returns 'cuda:N' on PyTorch <=2.4, which
    matches the previous Safetensors behavior.
    
    See [1]: pytorch/pytorch#129119
    
    Signed-off-by: Dmitry Rogozhkin <[email protected]>
    dvrogozh committed Jul 31, 2024
    13f36cc
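
The fallback described in the second commit message can be illustrated with a minimal Python sketch. This is not the code from this PR (the actual change is made in the project's Rust sources); `resolve_device`, `legacy_query`, and `accelerator_query` are hypothetical names introduced here, and the `query_index` callable is an assumed stand-in for `torch.device(N)` so that the example runs without PyTorch installed.

```python
from typing import Callable, Union

DeviceSpec = Union[int, str]


def resolve_device(spec: DeviceSpec, query_index: Callable[[int], str]) -> str:
    """Resolve a device spec to a device string.

    Hypothetical helper: a bare integer index is delegated to `query_index`
    (standing in for `torch.device(N)`) instead of being hard-coded to
    'cuda:N'. Explicit strings such as 'cpu' or 'xpu:1' pass through as-is.
    """
    # bool is a subclass of int in Python; reject it explicitly.
    if isinstance(spec, bool):
        raise TypeError("device spec must be an int index or a string")
    if isinstance(spec, int):
        if spec < 0:
            raise ValueError("device index must be non-negative")
        return query_index(spec)
    if isinstance(spec, str):
        return spec
    raise TypeError("device spec must be an int index or a string")


def legacy_query(index: int) -> str:
    """Models the PyTorch <= 2.4 behavior described in the commit message."""
    return f"cuda:{index}"


def accelerator_query(kind: str) -> Callable[[int], str]:
    """Models PyTorch >= 2.5, where the index maps to the current device type.

    `kind` is an assumed device type name (for example 'xpu').
    """
    return lambda index: f"{kind}:{index}"


if __name__ == "__main__":
    print(resolve_device(0, legacy_query))
    print(resolve_device(1, accelerator_query("xpu")))
    print(resolve_device("cpu", legacy_query))
```

With PyTorch installed, the stand-in could plausibly be replaced by `lambda i: str(torch.device(i))`. Delegating the lookup this way keeps the older `cuda:N` result on PyTorch <=2.4 while picking up the new behavior on later versions, which is the backward-compatibility argument made in the commit message.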