Skip to content

Commit

Permalink
[XPU] support op builder from intel_extension_for_pytorch kernel path (
Browse files Browse the repository at this point in the history
…#5425)

#Motivation
From our next release, xpu DeepSpeed related kernels would be put into
intel_extension_for_pytorch. This PR is to add new op builders and use
kernel path from intel_extension_for_pytorch. More ops like MOE and WOQ
will be added.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
3 people authored Jun 20, 2024
1 parent 0c979d6 commit d254d75
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 25 deletions.
46 changes: 22 additions & 24 deletions accelerator/xpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
import functools

import importlib
import inspect


class XPU_Accelerator(DeepSpeedAccelerator):

Expand All @@ -17,6 +20,7 @@ def __init__(self):
self._communication_backend_name = 'ccl'
self._compile_backend = "inductor"
self.aligned_tensors = []
self.class_dict = None

def is_synchronized_device(self):
return False
Expand Down Expand Up @@ -257,35 +261,29 @@ def on_accelerator(self, tensor):
else:
return False

def _lazy_init_class_dict(self):
if self.class_dict:
return

op_builder_module = importlib.import_module(self.op_builder_dir())

# get op builder class from op_builder/xpu/__init__.py
self.class_dict = {}
for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
self.class_dict[class_name] = class_obj

# create an instance of op builder and return, name specified by class_name
def create_op_builder(self, op_name):
builder_class = self.get_op_builder(op_name)
if builder_class != None:
return builder_class()
return None
def create_op_builder(self, class_name):
builder_class = self.get_op_builder(class_name)
return builder_class()

# return an op builder class, name specified by class_name
def get_op_builder(self, class_name):
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
except ImportError:
from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder

if class_name == "AsyncIOBuilder":
return AsyncIOBuilder
elif class_name == "CPUAdagradBuilder":
return CPUAdagradBuilder
elif class_name == "CPUAdamBuilder":
return CPUAdamBuilder
elif class_name == "FusedAdamBuilder":
return FusedAdamBuilder
elif class_name == "PackbitsBuilder":
return PackbitsBuilder
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]
else:
return None
return self.class_dict['NotImplementedBuilder']

def build_extension(self):
try:
Expand Down
3 changes: 2 additions & 1 deletion deepspeed/comm/ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import NotImplementedBuilder
from .reduce_op import ReduceOp
from .torch import TorchBackend


def build_ccl_op():
builder = get_accelerator().create_op_builder("CCLCommBuilder")
if builder is None:
if builder is None or NotImplementedBuilder:
return None
ccl_cpp_module = builder.load()
print(f'DeepSpeed {builder.absolute_name()} built successfully')
Expand Down
3 changes: 3 additions & 0 deletions op_builder/xpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@
from .cpu_adagrad import CPUAdagradBuilder
from .fused_adam import FusedAdamBuilder
from .async_io import AsyncIOBuilder
from .inference import InferenceBuilder
from .flash_attn import FlashAttentionBuilder
from .no_impl import NotImplementedBuilder
from .packbits import PackbitsBuilder
53 changes: 53 additions & 0 deletions op_builder/xpu/flash_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import SYCLOpBuilder


class FlashAttentionBuilderObject():

def __init__(self):
pass

# general functions
def flash_attn_func_v2(self, q, k, v, dropout_p, softmax_scale, is_causal):
try:
import torch
import intel_extension_for_pytorch # noqa
return torch.nn.functional.scaled_dot_product_attention(q,
k,
v,
dropout_p=dropout_p,
is_causal=is_causal,
scale=softmax_scale)
except ImportError:
raise ImportError(
"Please install pytorch and intel_extension_for_pytorch to include scaled dot product attention.")


class FlashAttentionBuilder(SYCLOpBuilder):
BUILD_VAR = "DS_BUILD_FlashAttention"
NAME = "flash_attn"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.{self.NAME}_op'

def sources(self):
return

def include_paths(self):
return []

def extra_ldflags(self):
return []

def cxx_args(self):
return []

def load(self):
return FlashAttentionBuilderObject()
36 changes: 36 additions & 0 deletions op_builder/xpu/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import SYCLOpBuilder


class InferenceBuilder(SYCLOpBuilder):
BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
NAME = "transformer_inference"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.transformer.inference.{self.NAME}_op'

def sources(self):
return

def include_paths(self):
return []

def extra_ldflags(self):
return []

def cxx_args(self):
return []

def load(self):
try:
import intel_extension_for_pytorch.deepspeed
return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference
except ImportError:
raise ImportError("Please install intel-extension-for-pytorch >= 2.1.30 to include DeepSpeed kernels.")
33 changes: 33 additions & 0 deletions op_builder/xpu/no_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import SYCLOpBuilder


class NotImplementedBuilder(SYCLOpBuilder):
BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED"
NAME = "deepspeed_not_implemented"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.{self.NAME}_op'

def load(self, verbose=True):
raise ValueError("This op had not been implemented on XPU backend.")

def sources(self):
return []

def cxx_args(self):
return []

def extra_ldflags(self):
return []

def include_paths(self):
return []

0 comments on commit d254d75

Please sign in to comment.