Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPUAdam fp16 and bf16 support #5409

Merged
merged 37 commits into from
May 20, 2024
Merged

CPUAdam fp16 and bf16 support #5409

merged 37 commits into from
May 20, 2024

Conversation

BacharL
Copy link
Collaborator

@BacharL BacharL commented Apr 14, 2024

Hi.
Please review the following changes
I added support for BF16 to cpu adam. BF16, FP16 and float are supported at compilation time. the correct template is called at runtime according to input params dtype.

@BacharL BacharL marked this pull request as draft April 15, 2024 19:57
@BacharL BacharL force-pushed the hab_cpu_adam branch 2 times, most recently from a9d5b2c to 11ddda8 Compare April 17, 2024 11:58
@BacharL BacharL changed the title [SW-173858] CPUAdam fp16 and bf16 support CPUAdam fp16 and bf16 support Apr 17, 2024
@BacharL BacharL marked this pull request as ready for review May 2, 2024 07:06
@BacharL BacharL requested review from tjruwase and loadams as code owners May 2, 2024 07:06
csrc/includes/cpu_adagrad.h Outdated Show resolved Hide resolved
csrc/includes/cpu_adagrad.h Outdated Show resolved Hide resolved
op_builder/cpu/cpu_adam.py Outdated Show resolved Hide resolved
#endif

typedef HALF_DTYPE ds_half_precision_t;
Copy link
Contributor

@tjruwase tjruwase May 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BacharL, this amazing PR of yours has made me realize that ds_half_precision_t was not well thought out for optimizer offloading. It did not anticipate that device type could be anything other than fp16 and bf16. Consequently, we now require users to set a confusing compiler option: -DHALF_DTYPE=float.

In retrospect, it seems ds_device_precision_t is a better name, and compiler option could be -DDEVICE_DTYPE=float

Similarly, the half_precision bool variables scattered around the code is also redundant. Thanks for removing some of them in this PR.

@BacharL, what do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will fix the naming of these, however, I think ds_half_precision_t causes discrepancy in buffer handling.
Step_AVX will handle the buffer with on cpu with FP16 of BF16 as implemented in AVX. The cast to fp32 is done in simd.h SIMD_LOAD_XX macros.
But the remaining part of the buffer will be handled in Step_1 with dtype from the device, which may be __half or __bfloat16. This dtype may not provide bit exact results as CPU implementation. The cast to fp32 is done in Step_1 inside the main loop.
I think even AVX/Native C++ cast may not be compatible. So if we already have this issue, why do we need to add device dtypes here? we can limit this code to c10::Half or c10::BFloat16 only.

We can also remove more half_precision varaibles by passing _params and dev_params with correct type instead of float* and fix their usage inside cuda and avx code

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for raising questions about the discrepancies in this code base. I am aligned that the code is unduly complicated to the point of being bug prone. I think this occurred because of the uncontrolled evolution of the code. Hopefully, some measure of cleanliness can be restored through careful usage of templates, and torch data type support. In the meantime, I will answer your question below.

I think even AVX/Native C++ cast may not be compatible. So if we already have this issue, why do we need to add device dtypes here? we can limit this code to c10::Half or c10::BFloat16 only

I think adding device dtype is useful for making the data type conversions of offloading computation explicit.

  1. Forward/backward computation on the device using device_dtype (currently fp16, bf16, or fp32) with output of gradients in device_dtype.
  2. Conversion step of gradients from device_dtype to host_dtype which is input for subsequent optimizer computation on host.
  3. Optimizer step computation on host using host_dtype (currently fp32) and returning param of host_dtype.
  4. Conversion of (updated) host_dtype param into device_dtype param which is input for subsequent forward/backward computation on device.

@tjruwase
Copy link
Contributor

tjruwase commented May 4, 2024

@BacharL, thanks for this incredible improvement to the offloading optimizers and op builders. I left a few comments and questions, but overall looks good to me.

BacharL added 12 commits May 5, 2024 17:21
Change-Id: I30df727076f25bcb95c5c16bce960b38950c8eb1
Change-Id: I6edda27251ffd09d514d8bc0ee9f37b5101e9508
Change-Id: I0954147a55d3687cf1dadcf8e739d6ec968ffb79
Change-Id: I327cfedd529acec4371e1a091730ff45b9275363
Change-Id: Iae33f22a021a9c4c521e82981617617e8fad6f6e
Change-Id: I1de599a4398f3fe38a4056b9bfc0ca9fbf06f4aa
cpu adam will use dtype from input tensor
no need for HALF_DTYPE define during compilation

Change-Id: I069f994d5229e88e75d092e2236b5fbafd8db994
Change-Id: Iad82ee3f12d6f2ba460b6f013b557836404583c3
Change-Id: I75c715cdbaeab7ff89c1dab2b34e3340713ee650
Change-Id: I26064fd65c8708e6944c83ababe00147bfe62967
Change-Id: I8d77e52aba3072151e337b8453684fe3ee0f873d
Change-Id: Idd23dd7aeba7c9da80e1f3a1f3ec307c033b4c7c
BacharL added 5 commits May 8, 2024 12:24
Change-Id: Icb880493fc63bd784b1e299a4e06348f25b74544
Change-Id: I34a4a0f866cd6b1884202055866ffceb5a7d0da2
Change-Id: I91d9eb5c209bff59426fcfcc42a68376ee67d8e0
Change-Id: Ic93a3e630083c110bb70373a4e5f3b364986a436
Change-Id: I87ab075bc53a94a8cbfa4daddc11170b8ee13e95
@BacharL
Copy link
Collaborator Author

BacharL commented May 8, 2024

Added templated invoker to help selecting the implementation
The map stores function pointers to templated functions, the key is the type enum. At initialization all supported dtypes are templated and inserted into the map.
I didn't clean ds_adagrad_step_plus_copy and related code under __ENABLE_CUDA__ but also couldn't test it.

tjruwase and others added 11 commits May 8, 2024 11:43
This reverts commit f18eef3.
Change-Id: I99b134c5de26f7e2a227d47bd84d0a5070b786b5
Change-Id: Iec76ef86c2c5ebd17138f16f0d14b9b62138c7ee
Change-Id: I6205b2a8ca53cf09f1d6c6c6123e242b8fabcf56
Change-Id: I57c175f61644a146a83310ecfbfc86fc638cc440
Change-Id: Ie9feb6049e20693c330b2c65e3b4fa77f29adea6
Change-Id: I79bcd63552820182b20e1def8a2e0b2f4490f706
Change-Id: I144160a40a725a71a02cf5bdb54f93cb5abdca43
Change-Id: Id2efe0143d092eeb893b52f68ed55e9c90a1b3c6
BacharL and others added 3 commits May 19, 2024 12:04
Change-Id: Ibbeaab192f2bc1771bcdd40b4927290076c9ca81
Change-Id: I3c1ebedfa433553138cb7ddf04dac6bdcbc0c68c
@tjruwase tjruwase added this pull request to the merge queue May 20, 2024
Merged via the queue into microsoft:master with commit 69af361 May 20, 2024
13 checks passed
sfc-gh-reyazda pushed a commit to Snowflake-Labs/DeepSpeed that referenced this pull request Jun 10, 2024
Hi.
Please review the following changes
I added support for BF16 to cpu adam. BF16, FP16 and float are supported
at compilation time. the correct template is called at runtime according
to input params dtype.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants