I have a tensor with dtype torch.bfloat16 on a Kaggle TPU v3-8. After the conv1 and conv2 operations, the returned tensor has dtype torch.float32. Is there any way (an environment variable or similar) to get the return type back to torch.bfloat16?
@ghost Environment variables such as XLA_USE_BF16 and XLA_DOWNCAST_BF16 are now deprecated and no longer needed. The conv is likely returning float32 for precision reasons, so the simplest fix is to call .to(torch.bfloat16) on the output while the tensor is still on the XLA device. If precision does not matter much, you can instead convert the conv itself to bfloat16, or modify its forward pass to return a bfloat16 tensor.
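A minimal sketch of the three options from the reply above, assuming a plain nn.Conv2d with arbitrary sizes. It runs on CPU so it works anywhere (on a TPU you would move the modules and tensors to the XLA device instead), and it does not reproduce the float32 output seen on the TPU; it only shows how each option ends up with a bfloat16 result. The wrapper class Bf16OutputConv is a hypothetical name used for illustration, not part of PyTorch or PyTorch/XLA.

```python
import torch
import torch.nn as nn

# CPU is used so the sketch runs anywhere; on a TPU you would move the
# modules and tensors to the XLA device instead.
device = torch.device("cpu")


class Bf16OutputConv(nn.Module):
    """Hypothetical wrapper: runs the wrapped conv, then casts its output to bfloat16."""

    def __init__(self, conv: nn.Module):
        super().__init__()
        self.conv = conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Match the input to the conv's weight dtype, run the conv,
        # then downcast the result to bfloat16.
        out = self.conv(x.to(self.conv.weight.dtype))
        return out.to(torch.bfloat16)


torch.manual_seed(0)
x = torch.randn(1, 3, 16, 16, device=device, dtype=torch.bfloat16)

# Option 1: keep the conv in float32 and cast its output back to bfloat16.
conv_fp32 = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device)
y1 = conv_fp32(x.float()).to(torch.bfloat16)

# Option 2: convert the conv itself (weights and bias) to bfloat16,
# trading some precision for memory and speed.
conv_bf16 = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device, dtype=torch.bfloat16)
y2 = conv_bf16(x)

# Option 3: wrap the conv so its forward pass always returns bfloat16.
wrapped = Bf16OutputConv(nn.Conv2d(3, 8, kernel_size=3, padding=1)).to(device)
y3 = wrapped(x)

print(y1.dtype, y2.dtype, y3.dtype)
```

Option 1 keeps the conv's arithmetic in float32 and only downcasts the result, which matches the reply's advice when precision matters; option 2 runs the whole conv in bfloat16.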