Support Float32 and other precisions #196

Closed
yebai opened this issue Jul 14, 2024 · 10 comments
Assignees: willtebbutt
Labels: enhancement (New feature or request), high priority

yebai (Contributor) commented Jul 14, 2024

The assumption of only supporting Float64 is too strong. Are there any good reasons?

I am labelling this as a bug to signal its priority.

yebai added the bug label Jul 14, 2024
willtebbutt added the enhancement and high priority labels, and removed the bug label, Jul 22, 2024
willtebbutt (Member) commented:

The reason is a lack of testing infrastructure which can properly handle Float32 -- with multiple precisions involved, it becomes a little trickier to determine appropriate accuracy thresholds for finite differencing, etc.

I've removed the bug label, in favour of enhancement + high priority, because bug isn't an accurate description of what's going on here.
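
For context on why the thresholds are precision-dependent: the finite-differencing step size and the accuracy tolerances need to scale with the machine epsilon of the element type. A rough illustration (the specific constants here are illustrative, not ones Mooncake has settled on):

```julia
# A central-difference step is typically taken on the order of cbrt(eps(T)),
# so both the step and the pass/fail tolerances loosen considerably at Float32.
fd_step(::Type{T}) where {T<:AbstractFloat} = cbrt(eps(T))

fd_step(Float64)  # ≈ 6.1e-6
fd_step(Float32)  # ≈ 4.9e-3
```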

willtebbutt (Member) commented:

@yebai and @Red-Portal discovered that we're going to need a rule for Core.Intrinsics.fptrunc -- see TuringLang/AdvancedVI.jl#71
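
For reference, Core.Intrinsics.fptrunc is what Float32(x::Float64) lowers to. A minimal sketch of the mathematics such a rule has to encode (this is not Mooncake's rule interface, just the forward/reverse behaviour):

```julia
# Forward: truncate a Float64 down to Float32 (this is what Float32(x::Float64) calls).
y = Core.Intrinsics.fptrunc(Float32, 1.0)  # 1.0f0

# x -> fptrunc(Float32, x) has slope 1 almost everywhere, so the reverse pass
# essentially just widens the incoming Float32 cotangent back to Float64.
pullback(ȳ::Float32) = Float64(ȳ)
```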

willtebbutt (Member) commented Jul 22, 2024

@yebai the plan for this is below:

We plan to support testing functions involving uniform-precision data. For example,

f(x::Float64, y::Float64)
f(x::Float32, y::Float32)

should be fine, but we do not intend to add support for functions containing a range of precisions. For example,

f(x::Float32, y::Float64)

would not be testable.

In order to support this, we would need to:

  1. produce some functionality which checks that only a single precision has been used in the arguments to a test case,
  2. add functionality to choose the finite-differencing step size and the various accuracy thresholds based on whether Float16 / Float32 / Float64 is used (a rough sketch of both pieces is below).
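
A rough sketch of what both pieces could look like; the helper names here are hypothetical, not Mooncake's actual testing API:

```julia
# Hypothetical helper: the IEEE float types present in a single argument.
float_types(x::Base.IEEEFloat) = (typeof(x),)
float_types(x::AbstractArray{<:Base.IEEEFloat}) = (eltype(x),)
float_types(x) = ()

# Hypothetical helper: insist that a test case uses at most one floating-point precision.
function check_uniform_precision(args...)
    precisions = unique(T for x in args for T in float_types(x))
    length(precisions) > 1 && error(
        "Test case mixes floating-point precisions $precisions. Use a single " *
        "precision throughout, or test each precision separately."
    )
    return precisions
end

# Hypothetical helper: finite-differencing configuration chosen from the precision in use.
fd_config(::Type{T}) where {T<:Base.IEEEFloat} = (step=cbrt(eps(T)), atol=100 * eps(T))
```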

Once this is in place, we will need to:

  1. add tests for Float32 all over the place and fix any problems which appear, and
  2. add support for specific functions, such as Core.Intrinsics.fptrunc, which we do not currently support.

yebai (Contributor, Author) commented Jul 22, 2024

That sounds good to me; maybe we could promote low-precision types to the highest precision present in the arguments by default?

willtebbutt (Member) commented Jul 22, 2024

I think I'd rather just provide an informative error message, to avoid users thinking they're testing one thing while actually testing another. I can easily imagine a user accidentally having a Float64 somewhere in their input: the whole thing would get silently promoted to Float64, and they'd end up testing something other than what they meant to test.

willtebbutt (Member) commented Nov 26, 2024

@penelopeysm discovered in TuringLang/docs#559 (comment) that Mooncake's current rule for matrix-matrix multiplication in LuxLib doesn't successfully handle the case that the two input arrays contain numbers at different precisions.

It really shouldn't be too hard to handle this properly -- we would just need to re-write the current from_rrule implementations of the various variants of matrix-matrix multiplication found around here to have proper Mooncake rules which

  1. define is_primitive for arrays whose elements are subtypes of IEEEFloat,
  2. define the rrule!!s in such a way that the correct element type is adhered to, and
  3. add a catch-all method of rrule!! for all other element types, which always errors with a sensible message that users can use to work out how to modify their code (a sketch of that kind of guard is below).
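
To illustrate the kind of guard item 3 describes (illustrative only -- it deliberately avoids Mooncake's actual rule signatures):

```julia
# Illustrative element-type guard: a rule written for matching IEEEFloat element
# types falls back to an informative error for anything else (e.g. mixed precision).
function checked_matmul(A::AbstractMatrix, B::AbstractMatrix)
    if !(eltype(A) <: Base.IEEEFloat && eltype(A) === eltype(B))
        error(
            "matrix-matrix multiplication rule only supports matching IEEEFloat " *
            "element types; got $(eltype(A)) and $(eltype(B)). Convert one argument " *
            "so that both matrices use the same precision."
        )
    end
    return A * B
end
```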

Note: this is also a great opportunity to ensure excellent performance in Lux.jl -- the current implementations of the rules involve more allocations than are really needed, because we do not increment the gradients in-place.
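
On the allocation point, the standard matmul adjoint can accumulate into existing gradient buffers with the five-argument mul!, rather than materialising A' * dC and dC * B' and then adding them. A minimal sketch of that pattern (not the actual LuxLib rules):

```julia
using LinearAlgebra

# Reverse pass of C = A * B, accumulating in place into preallocated gradient buffers.
function increment_matmul_grads!(dA, dB, dC, A, B)
    mul!(dA, dC, B', true, true)  # dA .+= dC * B'
    mul!(dB, A', dC, true, true)  # dB .+= A' * dC
    return dA, dB
end
```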

willtebbutt self-assigned this Dec 9, 2024
willtebbutt (Member) commented Dec 11, 2024

I'm marking this as done, because I believe #414 basically solves it. We can re-open if there is a reason to.

yebai (Contributor, Author) commented Dec 11, 2024

> It really shouldn't be too hard to handle this properly -- we would just need to re-write the current from_rrule implementations of the various variants of matrix-matrix multiplication found around here to have proper Mooncake rules which
>
>   1. define is_primitive for arrays whose elements are subtypes of IEEEFloat,
>   2. define the rrule!!s in such a way that the correct element type is adhered to, and
>   3. add a catch-all method of rrule!! for all other element types, which always errors with a sensible message that users can use to work out how to modify their code.

Has this been addressed in previous PRs?

willtebbutt (Member) commented:

It seems to be fine for Float32 and Float64

yebai (Contributor, Author) commented Dec 11, 2024

> Mooncake's current rule for matrix-matrix multiplication in LuxLib doesn't successfully handle the case that the two input arrays contain numbers at different precisions.

Looking at the code again, it doesn't handle scenarios where input arrays have different element types, which is what @penelopeysm discovered while working with Bayesian NNs. Did I miss something here?

EDIT: updated my comment above; see also TuringLang/docs#559 (comment).
