Fixing several bugs in the inference-api and the kernels (#1951)
Co-authored-by: Jeff Rasley <[email protected]>
RezaYazdaniAminabadi and jeffra authored May 24, 2022
1 parent b8ff482 commit 8164ea9
Showing 14 changed files with 1,087 additions and 134 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/amd.yml
@@ -37,12 +37,23 @@ jobs:
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
sudo apt-get update
sudo apt-get install -y libaio-dev
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
pip install .
# Runs a set of commands using the runners shell
- name: Install deepspeed
run: |
sudo /opt/conda/bin/pip install .[dev,1bit,autotuning]
#python -c "from deepspeed.env_report import cli_main; cli_main()"
ds_report
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
9 changes: 9 additions & 0 deletions .github/workflows/nv-torch12-p40.yml
@@ -32,6 +32,15 @@ jobs:
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
pip install .
- name: Install deepspeed
run: |
pip install .[dev,autotuning]
11 changes: 11 additions & 0 deletions .github/workflows/nv-torch18-v100.yml
@@ -32,10 +32,21 @@ jobs:
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
pip install .
- name: Install deepspeed
run: |
pip install .[dev,1bit,autotuning,sparse_attn]
ds_report
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
8 changes: 6 additions & 2 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -4,6 +4,7 @@
#include <cuda_profiler_api.h>
#endif

namespace cg = cooperative_groups;
namespace cg = cooperative_groups;

__global__ void apply_rotary_pos_emb(float* mixed_query,
@@ -153,7 +154,9 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
int lane = id & 0x1f;

unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned seq_index = head_id % seq_len;
unsigned offset = head_id * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size;

constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
@@ -171,7 +174,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
@@ -183,7 +186,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);

mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
key_layer[k_offset + lane] = (__half)k;

lane += WARP_SIZE;
}
@@ -237,6 +240,7 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
bool,
bool,
cudaStream_t);

/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
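The substantive fix in this kernel is the key write-back index: the rotated key is now stored at k_offset, which gives every head its own stride of MAX_OUT_TOKES token slots in the key cache, instead of reusing the dense query offset. A minimal host-side sketch of the index math (illustrative only, not part of the commit; the MAX_OUT_TOKES value and the example numbers in main are assumptions):

#include <cstdio>

// Assumed per-head key-cache capacity; the real constant comes from the
// DeepSpeed headers and may differ.
constexpr unsigned MAX_OUT_TOKES = 1024;

// Dense layout used for the query: [heads * seq_len, head_size].
unsigned query_offset(unsigned head_id, unsigned head_size)
{
    return head_id * head_size;
}

// New key layout: each head owns MAX_OUT_TOKES token slots, so a token keeps
// the same cache slot across decoding steps.
unsigned key_offset(unsigned head_id, unsigned seq_len, unsigned head_size)
{
    unsigned seq_index = head_id % seq_len;  // token position within this head
    unsigned head_idx = head_id / seq_len;   // head index in the flattened leading dimension
    return (seq_index + head_idx * MAX_OUT_TOKES) * head_size;
}

int main()
{
    // Second token of the third head, with seq_len = 8 and head_size = 64.
    unsigned head_id = 2 * 8 + 1;
    std::printf("query offset: %u, key offset: %u\n",
                query_offset(head_id, 64),
                key_offset(head_id, 8, 64));
    return 0;
}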
45 changes: 26 additions & 19 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -317,12 +317,18 @@ __global__ void gptj_residual_add(float* input,
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];

data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
if (attnbias) {
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x += attn_bias.x;
data.y += attn_bias.y;
data.z += attn_bias.z;
data.w += attn_bias.w;
}
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w);

output_cast[offset] = data;
}
@@ -354,13 +360,11 @@ __global__ void gptj_residual_add(__half* input,
float2 res_vec = attn_cast[offset];

float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];

__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);

float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
@@ -373,18 +377,21 @@

float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);

float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);

low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
if (attn_bias) {
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x += attn_low_bias.x;
low_data.y += attn_low_bias.y;
high_data.x += attn_high_bias.x;
high_data.y += attn_high_bias.y;
}

low_data.x = low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x));
low_data.y = low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y));
high_data.x = high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x));
high_data.y = high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y));

vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
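Both gptj_residual_add kernels now add the attention-bias term only when the bias pointer is non-null, instead of dereferencing it unconditionally, and the guarded term is folded in before the mp_size scaling rather than after it. A scalar reference of the reworked accumulation (a sketch, not part of the commit; names paraphrase the kernel's float4/half2 locals and the example values are made up):

#include <cstdio>

// Scalar reference of the updated accumulation in gptj_residual_add:
// attn_bias may be null, in which case the attention-bias term is skipped.
float gptj_residual_add_ref(float data,              // element of the "data" vector
                            float out,               // element of "out" (output_cast)
                            float res,               // element of "res_vec" (attn_cast)
                            float bias,              // element of "bias_data" / "bias_vec"
                            const float* attn_bias,  // optional attention bias
                            float mp_size)
{
    if (attn_bias) data += *attn_bias;  // applied before the mp_size scaling
    return data * mp_size + (out + res + bias);
}

int main()
{
    float attn_bias = 0.25f;
    // Same inputs with and without the optional attention bias.
    std::printf("%f %f\n",
                gptj_residual_add_ref(1.0f, 2.0f, 3.0f, 4.0f, &attn_bias, 2.0f),
                gptj_residual_add_ref(1.0f, 2.0f, 3.0f, 4.0f, nullptr, 2.0f));
    return 0;
}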