Merge branch 'master' into lyj/enable_mixtral
loadams authored Mar 13, 2024
2 parents 9f029be + d9e12d3 commit 0f6c619
Showing 7 changed files with 70 additions and 10 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/hpu-gaudi2.yml
@@ -0,0 +1,48 @@
name: hpu-gaudi2

on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
pull_request:
paths:
- ".github/workflows/hpu-gaudi2.yml"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: read
issues: write

jobs:
unit-tests:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
image: vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v3

- name: Check container state
run: |
ldd --version
hl-smi
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install deepspeed
run: |
pip install .[dev]
ds_report
- name: Python environment
run: |
pip list
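
For reference, the "Check container state" and "Install deepspeed" steps above can be reproduced locally with a short Python check. This is a sketch only, assuming DeepSpeed has already been installed (pip install .[dev]) inside the Gaudi2 container; it uses DeepSpeed's get_accelerator abstraction, the same helper referenced in tests/unit/util.py further down.

# Sketch of the workflow's environment checks, assuming DeepSpeed and the Habana
# PyTorch stack are installed as in the gaudi-docker image above.
import torch
from deepspeed.accelerator import get_accelerator

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())     # expected False on Gaudi2
print("accelerator:", get_accelerator().device_name())  # expected 'hpu' on Gaudi2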
@@ -13,7 +13,7 @@
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef __HIP_PLATFORM_AMD__
#include <mma.h>
#endif
#include <stdio.h>
@@ -33,7 +33,7 @@ class BlasContext {
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef __HIP_PLATFORM_AMD__
cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH);
#endif
}
@@ -55,7 +55,7 @@ class BlasContext {

enum class BlasType { FP32, FP16, BF16 };

-#ifdef __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_AMD__
rocblas_operation get_trans_op(bool do_trans)
{
return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none;
@@ -99,7 +99,7 @@ int blas_gemm_ex(void* C,
const float* beta,
BlasType type)
{
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_AMD__
rocblas_operation_t transa_op = get_trans_op(transa);
rocblas_operation_t transb_op = get_trans_op(transb);

@@ -155,7 +155,7 @@ int blas_gemm_ex(void* C,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif

-#ifdef __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_AMD__
if (status != rocblas_status_success) {
#else
if (status != CUBLAS_STATUS_SUCCESS) {
@@ -190,7 +190,7 @@ int blas_strided_batched_gemm(void* C,
int batch,
BlasType type)
{
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_AMD__
rocblas_operation_t transa_op = get_trans_op(transa);
rocblas_operation_t transb_op = get_trans_op(transb);

@@ -257,7 +257,7 @@ int blas_strided_batched_gemm(void* C,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif

-#ifdef __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_AMD__
if (status != rocblas_status_success) {
#else
if (status != CUBLAS_STATUS_SUCCESS) {
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp
@@ -23,7 +23,7 @@ copies.
*/
torch::Tensor allocate_fast_host_buffer(torch::Tensor device_mirror)
{
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_AMD__
auto options =
torch::TensorOptions().device(torch::kCPU).pinned_memory(true).dtype(device_mirror.dtype());
auto buffer = torch::empty(device_mirror.sizes(), options);
2 changes: 2 additions & 0 deletions deepspeed/module_inject/containers/bloom.py
@@ -23,12 +23,14 @@ def __init__(self, **kwargs):

# All model specific things should be defined here instead of the base class.
self.bigscience_bloom = True
+self.triangular_masking = False

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config

self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
self.module.config.scale_attention = self.scale_attention
+self.module.config.invert_mask = False
return self.module

def attention_qkv_mp(self, mp_replace, reversed_dim=False):
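
The two added lines follow the pattern noted in the comment above: model-specific flags are set on the container subclass, and create_module copies what the inference kernels need onto the module config. Below is a rough sketch of that pattern; the class names match DeepSpeed's containers, but the base-class attributes shown are assumptions for illustration, not taken from this diff.

# Illustration only: BaseTransformerContainer internals are assumed here.
class BaseTransformerContainer:
    def __init__(self, **kwargs):
        self.scale_attention = True      # generic defaults live in the base class
        self.triangular_masking = True

class DS_BloomContainer(BaseTransformerContainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bigscience_bloom = True     # model-specific overrides live in the subclass
        self.triangular_masking = False  # causal masking comes via the provided mask/alibi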
5 changes: 4 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -43,6 +43,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
use_triton: This flag is to enable triton kernels in inference or not.
+invert_mask: If True, the attention mask is inverted when passed to the attention block.
"""

def __init__(self,
@@ -80,7 +81,8 @@ def __init__(self,
use_triton=False,
triton_autotune=False,
num_kv=-1,
-rope_theta=10000):
+rope_theta=10000,
+invert_mask=True):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -116,6 +118,7 @@ def __init__(self,
self.triton_autotune = triton_autotune
self.num_kv = num_kv
self.rope_theta = rope_theta
+self.invert_mask = invert_mask

@classmethod
def from_dict(cls, json_object):
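
The new invert_mask flag defaults to True, so existing configs keep the previous (1 - mask) behavior and only containers that opt out (such as BLOOM above) change anything. A hypothetical construction sketch follows; the size arguments are illustrative, and only invert_mask is the parameter added by this diff.

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig

# Illustrative sizes; invert_mask=False mirrors what the BLOOM container sets.
cfg = DeepSpeedInferenceConfig(hidden_size=1024, heads=16, num_hidden_layers=2,
                               invert_mask=False)
assert cfg.invert_mask is False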
6 changes: 5 additions & 1 deletion deepspeed/ops/transformer/inference/ds_attention.py
@@ -254,8 +254,12 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
if input_mask.dtype == torch.bool:
input_mask = input_mask.long()

+# Invert input_mask per transformer implementation (e.g., in BLOOM it is already inverted)
+if self.config.invert_mask:
+input_mask = 1 - input_mask

attention_probs = self.softmax_func(attn_scores=attention_scores,
-attn_mask=((1 - input_mask).to(target_dtype) * minus_inf),
+attn_mask=input_mask.to(target_dtype) * minus_inf,
alibi=alibi,
triangular=(self.config.triangular_masking
and (attention_scores.shape[-2] > 1)),
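
The effect of the flag on the additive attention bias can be seen in a small standalone sketch (minus_inf stands in for the large negative constant used by the kernel; the mask values are illustrative).

import torch

minus_inf = -10000.0                         # stand-in for the kernel's constant
hf_style_mask = torch.tensor([1, 1, 1, 0])   # 1 = attend, 0 = padding (default convention)

# invert_mask=True (default): flip first, then scale
bias_default = (1 - hf_style_mask).float() * minus_inf   # [0, 0, 0, -10000]

# invert_mask=False (e.g. BLOOM): the incoming mask already marks positions to suppress
bloom_mask = torch.tensor([0, 0, 0, 1])
bias_bloom = bloom_mask.float() * minus_inf              # [0, 0, 0, -10000]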
3 changes: 3 additions & 0 deletions tests/unit/util.py
@@ -47,11 +47,14 @@ def bf16_required_version_check(accelerator_check=True):
cuda_version_available = CUDA_MAJOR >= 11
nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
npu_available = get_accelerator().device_name() == 'npu'
+hpu_available = get_accelerator().device_name() == 'hpu'

if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass:
return True
elif npu_available:
return True
+elif hpu_available:
+return True
else:
return False

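
With this change, bf16_required_version_check short-circuits to True on HPU the same way it already does on NPU. A hedged usage sketch of how a unit test typically gates on it follows; the import path mirrors how tests import from tests/unit/util.py, and the test body is a placeholder.

import pytest
from unit.util import bf16_required_version_check  # assumes the tests/ directory layout on sys.path

@pytest.mark.skipif(
    not bf16_required_version_check(),
    reason="bf16 requirements (torch/CUDA/NCCL, NPU, or HPU) not met",
)
def test_bf16_path():
    ...  # placeholder body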
