CPUAdam fp16 and bf16 support #5409
Conversation
a9d5b2c to 11ddda8
csrc/includes/cpu_adam.h (outdated)
#endif
typedef HALF_DTYPE ds_half_precision_t;
@BacharL, this amazing PR of yours has made me realize that ds_half_precision_t was not well thought out for optimizer offloading. It did not anticipate that the device type could be anything other than fp16 and bf16. Consequently, we now require users to set a confusing compiler option: -DHALF_DTYPE=float.
In retrospect, it seems ds_device_precision_t is a better name, and the compiler option could be -DDEVICE_DTYPE=float.
Similarly, the half_precision bool variables scattered around the code are also redundant. Thanks for removing some of them in this PR.
@BacharL, what do you think?
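For illustration, a minimal sketch of how the suggested rename could look; DEVICE_DTYPE and ds_device_precision_t are the names proposed in this comment, not identifiers that exist in the current code:

```cpp
// Hypothetical sketch of the proposed rename; not the current DeepSpeed code.
// The device data type is chosen at build time, e.g.
//   -DDEVICE_DTYPE=__half    for an fp16 device
//   -DDEVICE_DTYPE=float     for an fp32 device
#ifndef DEVICE_DTYPE
#define DEVICE_DTYPE float
#endif

// Clearer replacement for ds_half_precision_t, whose name wrongly implies
// the device type is always a 16-bit float.
typedef DEVICE_DTYPE ds_device_precision_t;
```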
I will fix the naming of these; however, I think ds_half_precision_t causes a discrepancy in buffer handling.
Step_AVX handles its part of the buffer on the CPU with FP16 or BF16 as implemented in AVX; the cast to fp32 is done in the SIMD_LOAD_XX macros in simd.h.
But the remaining part of the buffer is handled in Step_1 with the dtype from the device, which may be __half or __bfloat16. That dtype may not give bit-exact results compared to the CPU implementation; its cast to fp32 is done inside the main loop of Step_1.
I think even the AVX and native C++ casts may not be compatible. So if we already have this issue, why do we need to add device dtypes here? We can limit this code to c10::Half or c10::BFloat16 only.
We can also remove more half_precision variables by passing _params and dev_params with their correct types instead of float*, and fix their usage inside the CUDA and AVX code.
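As an illustration of that last point, a minimal sketch of templating the scalar tail on the parameter dtype instead of passing float* plus a half_precision flag; step_scalar_tail is a made-up name and the update is a stand-in, not the actual DeepSpeed code:

```cpp
#include <cstddef>

// ParamT would be float, c10::Half, or c10::BFloat16 in the real code.
// All math stays in fp32; only loads and stores touch ParamT, mirroring
// what the SIMD_LOAD_XX/SIMD_STORE_XX macros do in Step_AVX.
template <typename ParamT>
void step_scalar_tail(ParamT* params, const ParamT* grads, float* exp_avg, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i) {
        float g = static_cast<float>(grads[i]);      // load-time cast to fp32
        float p = static_cast<float>(params[i]);
        exp_avg[i] = 0.9f * exp_avg[i] + 0.1f * g;   // illustrative momentum update
        params[i] = static_cast<ParamT>(p - 1e-3f * exp_avg[i]);  // store-time cast back
    }
}
```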
Thanks for raising questions about the discrepancies in this code base. I agree that the code is unduly complicated, to the point of being bug-prone; I think this occurred because of the uncontrolled evolution of the code. Hopefully, some measure of cleanliness can be restored through careful usage of templates and torch data type support. In the meantime, I will answer your question below.
"I think even AVX/Native C++ cast may not be compatible. So if we already have this issue, why do we need to add device dtypes here? We can limit this code to c10::Half or c10::BFloat16 only."
I think adding a device dtype is useful for making the data type conversions of the offloading computation explicit (a rough sketch follows the list):
- Forward/backward computation on the device using device_dtype (currently fp16, bf16, or fp32), producing gradients in device_dtype.
- Conversion of gradients from device_dtype to host_dtype, which is the input for the subsequent optimizer computation on the host.
- Optimizer step computation on the host using host_dtype (currently fp32), returning params in host_dtype.
- Conversion of the (updated) host_dtype params into device_dtype params, which are the input for the subsequent forward/backward computation on the device.
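A minimal sketch of these four stages, assuming placeholder types DeviceT/HostT and a hypothetical offload_optimizer_step helper (neither is from the DeepSpeed sources):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical illustration of the four offloading stages listed above.
// DeviceT (fp16/bf16/fp32 on the accelerator) and HostT (currently fp32)
// are placeholders; all buffers are assumed to have the same length.
template <typename DeviceT, typename HostT>
void offload_optimizer_step(std::vector<DeviceT>& device_params,
                            const std::vector<DeviceT>& device_grads,
                            std::vector<HostT>& host_params,
                            std::vector<HostT>& host_grads)
{
    // 1. Forward/backward already ran on the device, producing device_grads
    //    in device_dtype (DeviceT).

    // 2. Convert gradients device_dtype -> host_dtype for the host optimizer.
    for (std::size_t i = 0; i < device_grads.size(); ++i)
        host_grads[i] = static_cast<HostT>(device_grads[i]);

    // 3. Optimizer step on the host entirely in host_dtype (fp32 today);
    //    a plain SGD update stands in for the real Adam math.
    for (std::size_t i = 0; i < host_params.size(); ++i)
        host_params[i] -= static_cast<HostT>(1e-3) * host_grads[i];

    // 4. Convert the updated params host_dtype -> device_dtype for the next
    //    forward/backward pass on the device.
    for (std::size_t i = 0; i < host_params.size(); ++i)
        device_params[i] = static_cast<DeviceT>(host_params[i]);
}
```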
@BacharL, thanks for this incredible improvement to the offloading optimizers and op builders. I left a few comments and questions, but overall looks good to me.
Change-Id: I30df727076f25bcb95c5c16bce960b38950c8eb1
Change-Id: I6edda27251ffd09d514d8bc0ee9f37b5101e9508
Change-Id: I0954147a55d3687cf1dadcf8e739d6ec968ffb79
Change-Id: I327cfedd529acec4371e1a091730ff45b9275363
Change-Id: Iae33f22a021a9c4c521e82981617617e8fad6f6e
Change-Id: I1de599a4398f3fe38a4056b9bfc0ca9fbf06f4aa
CPU Adam will use the dtype from the input tensor; no need for the HALF_DTYPE define during compilation. Change-Id: I069f994d5229e88e75d092e2236b5fbafd8db994
Change-Id: Iad82ee3f12d6f2ba460b6f013b557836404583c3
Change-Id: I75c715cdbaeab7ff89c1dab2b34e3340713ee650
Change-Id: I26064fd65c8708e6944c83ababe00147bfe62967
Change-Id: I8d77e52aba3072151e337b8453684fe3ee0f873d
Change-Id: Idd23dd7aeba7c9da80e1f3a1f3ec307c033b4c7c
Change-Id: Icb880493fc63bd784b1e299a4e06348f25b74544
Change-Id: I34a4a0f866cd6b1884202055866ffceb5a7d0da2
Change-Id: I91d9eb5c209bff59426fcfcc42a68376ee67d8e0
Change-Id: Ic93a3e630083c110bb70373a4e5f3b364986a436
Added templated invoker to help select the implementation.
This reverts commit f18eef3.
Change-Id: I99b134c5de26f7e2a227d47bd84d0a5070b786b5
Change-Id: Iec76ef86c2c5ebd17138f16f0d14b9b62138c7ee
Change-Id: I6205b2a8ca53cf09f1d6c6c6123e242b8fabcf56
Change-Id: I57c175f61644a146a83310ecfbfc86fc638cc440
Change-Id: Ie9feb6049e20693c330b2c65e3b4fa77f29adea6
Change-Id: I79bcd63552820182b20e1def8a2e0b2f4490f706
Change-Id: I144160a40a725a71a02cf5bdb54f93cb5abdca43
Change-Id: Id2efe0143d092eeb893b52f68ed55e9c90a1b3c6
Change-Id: Ic3867ed78a636e88b884f304b68e897650a47ddc
Change-Id: Ibbeaab192f2bc1771bcdd40b4927290076c9ca81
Change-Id: I3c1ebedfa433553138cb7ddf04dac6bdcbc0c68c
Hi. Please review the following changes. I added support for BF16 to CPU Adam. BF16, FP16, and float are supported at compilation time; the correct template is called at runtime according to the input params dtype.
Co-authored-by: Olatunji Ruwase <[email protected]>
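To illustrate the runtime selection described here, a hedged sketch of dispatching on the parameter tensor's dtype; step_invoker and ds_adam_step are placeholder names, not the actual DeepSpeed entry points:

```cpp
#include <torch/extension.h>

// Pick the templated CPU Adam step based on the dtype of the incoming
// parameter tensor (sketch only; the real update is omitted).
template <typename ParamT>
void step_invoker(torch::Tensor& params, torch::Tensor& grads)
{
    // ... templated optimizer update over ParamT buffers (math in fp32) ...
}

void ds_adam_step(torch::Tensor& params, torch::Tensor& grads)
{
    switch (params.scalar_type()) {
        case at::ScalarType::Half:     step_invoker<c10::Half>(params, grads); break;
        case at::ScalarType::BFloat16: step_invoker<c10::BFloat16>(params, grads); break;
        case at::ScalarType::Float:    step_invoker<float>(params, grads); break;
        default: TORCH_CHECK(false, "unsupported parameter dtype for CPU Adam");
    }
}
```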