-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce dtype inference and improve dtype in ops.numpy.*
#938
Conversation
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #938 +/- ##
===========================================
+ Coverage 72.28% 83.69% +11.40%
===========================================
Files 319 320 +1
Lines 28879 29058 +179
Branches 5529 5579 +50
===========================================
+ Hits 20876 24320 +3444
+ Misses 6632 3195 -3437
- Partials 1371 1543 +172
Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
I think it's a good idea to use dtype=None
instead of dtype="float32"
in signatures.
We should add dtype checks in unit tests for all operations affected here, to check that we're in fact getting the same dtype across backends, including for array
. I think there might be ops where some backends will return float64 instead of float32. This will help us avoid inconsistencies.
@@ -348,6 +353,7 @@ def less_equal(x1, x2): | |||
def linspace( | |||
start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 | |||
): | |||
dtype = dtype or config.floatx() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Things like this will deviate from the NumPy convention in the sense that NumPy tries to infer the dtype from argument dtypes. IMO defaulting to float32 is much better: simpler, more consistent. So I think we can go with it.
However if we're going to make this deviation, we should do it consistently, in all ops that infer output dtype from argument dtype, such as arange
.
The alternative is to stick to the NumPy dtype inference convention (but with float32 instead of float64).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should stick to the JAX dtype inference convention instead of NumPy, as it should be better suited for DL. What do you think?
We can consider reimplementing jnp.result_dtype
for all backends
https://github.com/google/jax/blob/2cba122bbe512f7927d165fdbb29108dcf0fe124/jax/_src/dtypes.py#L638
It may require some time if we decide to do so.
Also, we should start testing consistency between symbolic outputs and real op outputs. That's a of checks over all, so it would justify the introduce of a new TestCase for dtypes. |
I can add some new test cases in class NumpySymbolicDtypeTest(testing.TestCase):
...
class NumpyTensorDtypeTest(testing.TestCase):
... Is it good? However, It may take some time to implement the |
Yes, that sounds good!
We may be able to use a test parameterization to save time/code. We can parameterize the input dtype, for instance. But in some cases we may also be able to parameterize the op functions, for groups of ops that have similar arguments. |
ops.numpy.*
ops.numpy.*
Hi @fchollet I want to verify whether this PR is on the right track. I am attempting to implement a Keras Core version of If it is good, I will refactor some of |
Keras Core is becoming Keras 3, and we're switching development to the main repository! Please reopen this PR in the keras-team/keras repository. Unfortunately we aren't able to automatically transfer PRs (but we have transferred all issues). |
This PR unifies the default dtype behavior in
ops.numpy.*
and ensures that they respectbackend.floatx()
A subtle bug has been caught in
dropout_rnn_cell_test.py
:We should perform a custom mixed precision check because we can't initialize
cell
withdtype="mixed_float16"
inself.run_layer_test
.EDITED:
WIP:
backend.dtypes
functionalityops.numpy