
[FIX] Resolve #19 by fixing broken import #20

Closed · wants to merge 1 commit

Conversation

EtienneChollet

No description provided.

@balbasty (Owner)

I am not sure what to do here....

  1. My functions take multiple input tensors, so it seems we do need custom_fwd/custom_bwd, even though I only use pure PyTorch functions under the hood (see here)
  2. But I don't love having the device hardcoded like you currently have. What if CUDA is available but we are running some piece of code on the CPU? Is it going to crash?
  3. We anyway need a wrapper around custom_fwd/custom_bwd to support both the "old version" (which does not have the device_type argument) and the "new version" (which requires the device_type argument).
  4. Does it mean we must define different Function classes for CPU and GPU? This sounds so weird... Maybe I can define a generic wrapper that takes an undecorated Function class and returns two different classes decorated with the cuda and cpu decorators?
  5. Anyway, I need tests for this mixed-precision stuff. I don't think I have ever actually tried running the code in mixed precision.

@EtienneChollet do you want to help with this? (you don't have to!)

@balbasty (Owner)

I am closing this PR as the issue was fixed (differently) in #21.

balbasty closed this Sep 13, 2024