diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp
index a020c5d1cd5ef6..a7a8ae1f229a72 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp
@@ -147,6 +147,7 @@ class debug_configuration {
     int use_kv_cache_compression;               // Enable KV-cache compression
     int dynamic_quantize_group_size;            // Enable Dynamic quantization for fully connected primitive by specified group size
     int disable_horizontal_fc_fusion;           // Disable fc horizontal fusion
+    int disable_fc_swiglu_fusion;               // Disable fusion of swiglu into fc
     std::set<int64_t> dump_iteration;           // Dump n-th execution of network.
     std::vector<std::string> load_layers_raw_dump;  // List of layers to load dumped raw binary and filenames
     static const debug_configuration *get_instance();
diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp
index ab5cb53454b768..cc753d10aea9cd 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp
@@ -50,6 +50,11 @@ struct data_type_traits {
         return et.is_quantized() && et.bitwidth() == 8;
     }

+    static bool is_i4_u4(data_types data_type) {
+        auto et = ov::element::Type(data_type);
+        return et.bitwidth() == 4;
+    }
+
     static ov::element::Type max_type(ov::element::Type t1, ov::element::Type t2) {
         if (t1.bitwidth() < t2.bitwidth())
             return t2;
diff --git a/src/plugins/intel_gpu/src/graph/fully_connected.cpp b/src/plugins/intel_gpu/src/graph/fully_connected.cpp
index bc1e3e2e82b3ca..3ebbde2454c46b 100644
--- a/src/plugins/intel_gpu/src/graph/fully_connected.cpp
+++ b/src/plugins/intel_gpu/src/graph/fully_connected.cpp
@@ -7,6 +7,7 @@
 #include <string>
 #include <algorithm>
 #include "utils.hpp"
+#include "swiglu_inst.h"

 #include "matmul_shape_inference.hpp"

@@ -171,14 +172,36 @@ std::vector<layout> fully_connected_inst::calc_output_layouts(fully_connected_no
         output_type = impl_param.get_output_element_type();
     }

-    ov::op::v0::MatMul op;
-    op.set_transpose_b(true);
+    ov::op::v0::MatMul matmul_op;
+    matmul_op.set_transpose_b(true);
     std::vector<ShapeType> input_shapes = {
         input_layout.get<ShapeType>(),
         weights_layout.get<ShapeType>()
     };

-    std::vector<ShapeType> output_shapes = ov::op::v0::shape_infer(&op, input_shapes);
+    std::vector<ShapeType> output_shapes = ov::op::v0::shape_infer(&matmul_op, input_shapes);
+
+    bool has_swiglu = false;
+    auto& fused_prims = node.get_fused_primitives();
+    for (auto f : fused_prims) {
+        if (f.is_type<swiglu>()) {
+            has_swiglu = true;
+            OPENVINO_ASSERT(fused_prims.size() == 1, "Other operation is fused in addition to swiglu!");
+        }
+    }
+    if (has_swiglu) {
+        ov::op::internal::GLU swiglu_op;
+        OPENVINO_ASSERT(fused_prims.size() == 1);
+        OPENVINO_ASSERT(fused_prims[0].typed_desc<swiglu>()->glu_type == ov::op::internal::GLU::GluType::Swish);
+        swiglu_op.set_axis(fused_prims[0].typed_desc<swiglu>()->axis);
+        swiglu_op.set_split_lengths(fused_prims[0].typed_desc<swiglu>()->split_lengths);
+        swiglu_op.set_glu_type(fused_prims[0].typed_desc<swiglu>()->glu_type);
+        std::vector<ShapeType> input_shapes = {
+            output_shapes[0],
+            ShapeType(ov::Shape({})),
+            ShapeType(ov::Shape{2})
+        };
+        output_shapes = shape_infer(&swiglu_op, input_shapes);
+    }

     bool is_static = input_layout.is_static() && weights_layout.is_static();
     bool allow_new_shape_infer = impl_param.get_program().is_new_shape_infer();
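Note on the shape-inference change above: when swiglu is fused, MatMul shape inference runs first and GLU shape inference is applied on top, so the FC advertises only half of the weight OFM as its output features. A minimal standalone sketch of that rule, with illustrative shapes and a hypothetical helper name (not the plugin API):

```cpp
// Minimal sketch of the fused output-shape rule, assuming the gate and up
// projections share one weight tensor of shape [2*N, K]; with SwiGLU fused
// the visible output keeps only N features.
#include <cstdint>
#include <vector>

std::vector<int64_t> fused_fc_swiglu_output_shape(const std::vector<int64_t>& input,    // [B, M, K]
                                                  const std::vector<int64_t>& weights,  // [2*N, K]
                                                  bool swiglu_fused) {
    const int64_t out_f = swiglu_fused ? weights[0] / 2 : weights[0];
    return {input[0], input[1], out_f};  // [B, M, N] when fused, [B, M, 2*N] otherwise
}
```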
diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
index 60d1e8aa7e10b7..29b7cf58a19b54 100644
--- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
+++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
@@ -1,7 +1,7 @@
 // Copyright (C) 2018-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
-
+#include "intel_gpu/runtime/debug_configuration.hpp"
 #include "program_helpers.h"
 #include "pass_manager.h"

@@ -37,6 +37,7 @@
 #include "strided_slice_inst.h"
 #include "cum_sum_inst.h"
 #include "embedding_bag_inst.h"
+#include "swiglu_inst.h"
 #include "extract_image_patches_inst.h"
 #include "reduce_inst.h"
 #include "group_normalization_inst.h"
@@ -56,6 +57,7 @@ using namespace cldnn;
 void prepare_primitive_fusing::run(program& p) {
     fuse_reorders(p);
     remove_redundant_reshape(p);
+    fuse_swiglu(p);
     fuse_bias(p);
     fuse_simple_primitives(p);
     fuse_constant_transposes(p);
@@ -161,6 +163,46 @@ void prepare_primitive_fusing::fuse_reorders(program &p) {
     }
 }

+void prepare_primitive_fusing::fuse_swiglu(program &p) {
+    GPU_DEBUG_GET_INSTANCE(debug_config);
+    bool disable_fc_swiglu_fusion = false;
+    GPU_DEBUG_IF(debug_config->disable_fc_swiglu_fusion == 1)
+        disable_fc_swiglu_fusion = true;
+    // Apply only to high-performance GPUs
+    if (disable_fc_swiglu_fusion || p.get_engine().get_device_info().execution_units_count < 128)
+        return;
+    // TODO: support other GLU types and other weight data types
+    auto itr = p.get_processing_order().begin();
+    std::map<primitive_id, std::vector<std::pair<primitive_id, size_t>>> fusing_history;
+    while (itr != p.get_processing_order().end()) {
+        auto node_itr = itr++;
+        auto& node = (*node_itr);
+        if (node->is_type<swiglu>()) {
+            if (!node->get_dependency(0).is_type<fully_connected>())
+                continue;
+            auto swiglu_prim = node->get_kernel_impl_params()->typed_desc<swiglu>();
+            auto& fc_node = node->get_dependency(0);
+            if (node->get_dependencies().size() > 1)
+                continue;
+            if (!node->get_dependency(0).get_fused_primitives().empty())
+                continue;
+            auto in_dt = fc_node.get_input_layout(0).data_type;
+            if (in_dt != data_types::f16)
+                continue;
+            auto wt_dt = fc_node.get_input_layout(1).data_type;
+            if (!data_type_traits::is_i4_u4(wt_dt))
+                continue;
+            if (swiglu_prim->glu_type != ov::op::internal::GLU::GluType::Swish ||
+                !(swiglu_prim->axis == -1 || swiglu_prim->axis == static_cast<int64_t>(node->get_output_layout(0).get_partial_shape().size()) - 1))
+                continue;
+            GPU_DEBUG_TRACE_DETAIL << node->id() << " : fuse swiglu to " << fc_node.id() << std::endl;
+            GPU_DEBUG_TRACE_DETAIL << " - split axis : " << swiglu_prim->axis << std::endl;
+            GPU_DEBUG_TRACE_DETAIL << " - split length : " << swiglu_prim->split_lengths << std::endl;
+            p.fuse_nodes(fc_node, *node, &fusing_history);
+        }
+    }
+}
+
 void prepare_primitive_fusing::fuse_bias(program &p) {
     auto itr = p.get_processing_order().begin();
     while (itr != p.get_processing_order().end()) {
@@ -188,6 +230,17 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
         if (!is_bias_add)
             continue;

+        for (auto& dep : eltw_node.get_dependencies()) {
+            auto& fused_prims = dep.first->get_fused_primitives();
+            if (std::any_of(fused_prims.begin(), fused_prims.end(), [](const fused_primitive_desc& f_desc) {
+                return f_desc.is_type<swiglu>();
+            })) {
+                GPU_DEBUG_TRACE_DETAIL << "Skip fusing " << eltw_node.id() << " to " << dep.first->id() << " because "
+                                       << dep.first->id() << " has fused swiglu." << std::endl;
+                continue;
+            }
+        }
+
         auto is_3d_fully_connected = [](program_node& node) {
             if (!node.is_type<fully_connected>())
                 return false;
@@ -491,6 +544,13 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
     };

     auto fc_supports_fusings = [&](fully_connected_node& node) -> bool {
+        auto& fused_prims = node.get_fused_primitives();
+        if (std::any_of(fused_prims.begin(), fused_prims.end(), [](const fused_primitive_desc& f_desc) {
+            return f_desc.is_type<swiglu>();
+        })) {
+            GPU_DEBUG_TRACE_DETAIL << node.id() << " has fused swiglu. Skip fusing more primitives" << std::endl;
+            return false;
+        }
         if (lo.has_all_enabled_onednn_impls_optimization_attribute() &&
             lo.get_preferred_impl_type(node, format::any /*dummy*/) == impl_types::onednn) {
             return true;
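The fuse_swiglu pass above is deliberately narrow. A condensed view of its gating, as a sketch with illustrative parameter names (not the plugin API), may help reviewers check the conditions at a glance:

```cpp
// All of these must hold before the swiglu node is folded into the FC;
// mirrors the early-continue chain in fuse_swiglu() above.
bool can_fuse_swiglu(bool glu_type_is_swish, bool split_axis_is_last_dim,
                     bool swiglu_has_single_input, bool fc_has_no_other_fusions,
                     bool activations_are_f16, bool weights_are_int4_or_uint4,
                     int execution_units_count) {
    return glu_type_is_swish && split_axis_is_last_dim && swiglu_has_single_input &&
           fc_has_no_other_fusions && activations_are_f16 && weights_are_int4_or_uint4 &&
           execution_units_count >= 128;  // applied only on larger GPUs
}
```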
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp
index 04f691c2bd2ca9..110444c2c6255c 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp
@@ -132,15 +132,16 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
             return layouts;
         };

-        auto get_fc_output_layout = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout) {
+        auto get_fc_output_layout = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout, bool swiglu_fused) {
             auto updated_out_layout = output_layout;

             auto input0_pshape = input_layouts[0].get_partial_shape();
             auto input1_pshape = input_layouts[1].get_partial_shape();
             ov::PartialShape updated_out_pshape {input0_pshape[0], input1_pshape[0]};
+            const auto output_feature_size = swiglu_fused ? input1_pshape[0] / 2 : input1_pshape[0];
             if (primitive->input_size == 3) {
-                updated_out_pshape = { input0_pshape[0], input0_pshape[1], input1_pshape[0] };
+                updated_out_pshape = { input0_pshape[0], input0_pshape[1], output_feature_size };
             }
             updated_out_layout.set_partial_shape(updated_out_pshape);

@@ -149,6 +150,13 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
         bool allow_new_shape_infer = impl_param.get_program().is_new_shape_infer();
         auto updated_impl_param = impl_param;

+        bool swiglu_fused = false;
+        if (updated_impl_param.fused_desc.size() > 0) {
+            for (const auto& f : updated_impl_param.fused_desc) {
+                if (f.is_type<swiglu>())
+                    swiglu_fused = true;
+            }
+        }
         const auto input_layouts = get_fc_input_layouts(impl_param.input_layouts, allow_new_shape_infer);

         for (size_t i = 0; i < input_layouts.size(); ++i) {
@@ -156,7 +164,7 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
         }
         updated_impl_param.weights_layout = input_layouts[1];

-        updated_impl_param.output_layouts[0] = get_fc_output_layout(input_layouts, impl_param.get_output_layout());
+        updated_impl_param.output_layouts[0] = get_fc_output_layout(input_layouts, impl_param.get_output_layout(), swiglu_fused);

         return updated_impl_param;
     }
"kernel_selector/kernels/swiglu/swiglu_kernel_base.h" #include "kernel_selector/kernels/activation/activation_kernel_base.h" #include "kernel_selector/kernels/depth_to_space/depth_to_space_kernel_base.h" #include "kernel_selector/kernels/eltwise/eltwise_kernel_base.h" @@ -1009,7 +1011,12 @@ kernel_selector::activation_function get_kernel_selector_activation_param(activa } std::shared_ptr convert_fuse_params(std::shared_ptr p) { - if (p->type() == activation::type_id()) { + if (p->type() == swiglu::type_id()) { + auto casted = std::dynamic_pointer_cast(p); + auto axis = casted->_desc->axis; + auto split_length = casted->_desc->split_lengths; + return std::make_shared(axis, split_length); + } else if (p->type() == activation::type_id()) { auto casted = std::dynamic_pointer_cast(p); auto desc = casted->_desc; kernel_selector::base_activation_params p; diff --git a/src/plugins/intel_gpu/src/graph/include/pass_manager.h b/src/plugins/intel_gpu/src/graph/include/pass_manager.h index 9850c25a64ec5d..e4b05a6e2aaf9f 100644 --- a/src/plugins/intel_gpu/src/graph/include/pass_manager.h +++ b/src/plugins/intel_gpu/src/graph/include/pass_manager.h @@ -140,6 +140,7 @@ class prepare_primitive_fusing : public base_pass { private: void run(program& p) override; void fuse_bias(program &p); + void fuse_swiglu(program &p); void fuse_reorders(program& p); void fuse_simple_primitives(program &p); void fuse_constant_transposes(program &p); diff --git a/src/plugins/intel_gpu/src/graph/include/swiglu_inst.h b/src/plugins/intel_gpu/src/graph/include/swiglu_inst.h index 6a5ce08dc54bd2..755e9ab33c2db6 100644 --- a/src/plugins/intel_gpu/src/graph/include/swiglu_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/swiglu_inst.h @@ -10,6 +10,11 @@ namespace cldnn { +class SwigluFuseParams : public NodeFuseParams { +public: + SwigluFuseParams(std::shared_ptr desc) : NodeFuseParams(swiglu::type_id()), _desc(std::move(desc)) {} + std::shared_ptr _desc; +}; template <> struct typed_program_node : public typed_program_node_base { using parent = typed_program_node_base; @@ -19,6 +24,10 @@ struct typed_program_node : public typed_program_node_base { program_node& input(size_t index = 0) const { return get_dependency(index); } std::vector get_shape_infer_dependencies() const override { return {}; } + + std::shared_ptr get_fuse_params() const override { + return std::make_shared(typed_desc()); + } }; using swiglu_node = typed_program_node; diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index b51c7825b5a8fa..e33437ac87e3e5 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -38,6 +38,7 @@ #include "gather_inst.h" #include "broadcast_inst.h" #include "dynamic_quantize_inst.h" +#include "swiglu_inst.h" #include "experimental_detectron_roi_feature_extractor_inst.hpp" #include "impls/registry/implementation_manager.hpp" #include "impls/registry/registry.hpp" @@ -2591,6 +2592,16 @@ bool primitive_inst::is_valid_fusion() const { } else { if (fd.is_type() || fd.is_type()) continue; + if (fd.is_type()) { + OPENVINO_ASSERT(_node->is_type() && _node->get_preferred_impl_type() == impl_types::ocl); + if (!_node->get_selected_impl()) + return false; + // TODO : support ref kernel too + if (_node->get_selected_impl()->get_kernel_name().find("fully_connected_gpu_bf_tiled") != std::string::npos) + return true; + else + return false; + } OPENVINO_THROW("[GPU] Unsupported fused operation in dynamic shape: 
type=", fd.desc->type_string(), ", id=", fd.desc->id); } diff --git a/src/plugins/intel_gpu/src/graph/program_node.cpp b/src/plugins/intel_gpu/src/graph/program_node.cpp index 201fa3a155caa9..5161887b79e57a 100644 --- a/src/plugins/intel_gpu/src/graph/program_node.cpp +++ b/src/plugins/intel_gpu/src/graph/program_node.cpp @@ -10,6 +10,7 @@ #include "activation_inst.h" #include "reorder_inst.h" #include "quantize_inst.h" +#include "swiglu_inst.h" #include "intel_gpu/runtime/debug_configuration.hpp" #ifdef ENABLE_ONEDNN_FOR_GPU #include "convolution_inst.h" @@ -770,6 +771,15 @@ void program_node::save(cldnn::BinaryOutputBuffer& ob) const { ob << casted->_out_hi; ob << casted->_out_scale; ob << casted->_out_shift; + } else if (f_desc.f_param->type() == swiglu::type_id()) { + auto casted = std::dynamic_pointer_cast(f_desc.f_param); + if (get_program().has_node(casted->_desc->id)) { + ob << true; + ob << casted->_desc->id; + } else { + ob << false; + ob << casted->_desc; + } } ob << f_desc.deps.size(); @@ -975,6 +985,18 @@ void program_node::load(cldnn::BinaryInputBuffer& ib) { need_pre_shift, need_clamp, need_min_clamp, need_max_clamp, per_tensor_input_range, per_tensor_input_scale, per_tensor_input_shift, per_tensor_output_range, per_tensor_output_scale, per_tensor_output_shift, in_lo, in_hi, in_scale, in_shift, out_lo, out_hi, out_scale, out_shift); + } else if (f_param_type == swiglu::type_id()) { + ib >> exist_prim; + std::shared_ptr param_desc; + if (exist_prim) { + primitive_id desc_id; + ib >> desc_id; + param_desc = std::dynamic_pointer_cast(get_program().get_node_ptr(desc_id)->desc); + } else { + ib >> param_desc; + } + f_desc.f_param = std::make_shared(param_desc); + } else { f_desc.f_param = std::make_shared(f_param_type); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl index 201b59c160cf27..9e79e18dbbee7d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl @@ -95,6 +95,12 @@ KERNEL(quantize_input)( # error "fully_connected_gpu_bf_tiled.cl - TILE_K must be one of {1, 2, 4}" # endif #endif + +#ifdef SWIGLU_LENGTH +# if OUTER_OFM != 2 +# error "fully_connected_gpu_bf_tiled.cl - outer_ofm should be 2 when swiglu is fused" +# endif +#endif #if TILE_K == 4 && COMPRESSED_WEIGHTS_INT4 && FILTER_LAYOUT_OS_IS_YX_OSV32_ISV2 // Data stored in memory : f0k0k1|f16k0k1|f0k2k3|f16k2k3 // => unpack as f0k0k1|f0k2k3|f16k0k1|f16k2k3 so that the weight access order is preserved @@ -210,14 +216,27 @@ inline void FUNC(fc_bf_tiled_kernel_default)( // full dispatch pipeline. 
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
index 201b59c160cf27..9e79e18dbbee7d 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
@@ -95,6 +95,12 @@ KERNEL(quantize_input)(
 #       error "fully_connected_gpu_bf_tiled.cl - TILE_K must be one of {1, 2, 4}"
 #   endif
 #endif
+
+#ifdef SWIGLU_LENGTH
+#   if OUTER_OFM != 2
+#       error "fully_connected_gpu_bf_tiled.cl - OUTER_OFM should be 2 when swiglu is fused"
+#   endif
+#endif
 #if TILE_K == 4 && COMPRESSED_WEIGHTS_INT4 && FILTER_LAYOUT_OS_IS_YX_OSV32_ISV2
 // Data stored in memory : f0k0k1|f16k0k1|f0k2k3|f16k2k3
 // => unpack as f0k0k1|f0k2k3|f16k0k1|f16k2k3 so that the weight access order is preserved
@@ -210,14 +216,27 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
     // full dispatch pipeline.
     uint feature_mini_block = gid % DISPATCH_FSV;
     uint batch_mini_block = gid / DISPATCH_FSV % DISPATCH_BSV;
+    #ifdef SWIGLU_LENGTH
+    uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    #else
     uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, OUTER_OFM * TILE_OFM * SIMD) / DISPATCH_FSV);
     uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, OUTER_OFM * TILE_OFM * SIMD) / DISPATCH_FSV);
+    #endif

 #if USE_SLM
+    #ifdef SWIGLU_LENGTH
+    uint out_f = gid * (TILE_OFM * SIMD);
+    #else
     uint out_f = gid * (OUTER_OFM * TILE_OFM * SIMD);
+    #endif
     uint out_b = LWS_BATCHES * TILE_B * (uint)get_group_id(2) + local_id * TILE_B;
 #else
+    #ifdef SWIGLU_LENGTH
+    uint out_f = (feature_mega_block * DISPATCH_FSV + feature_mini_block) * (TILE_OFM * SIMD);
+    #else
     uint out_f = (feature_mega_block * DISPATCH_FSV + feature_mini_block) * (OUTER_OFM * TILE_OFM * SIMD);
+    #endif
     uint out_b = ((batch_mega_block * DISPATCH_BSV + batch_mini_block) * TILE_B);
 #endif

@@ -299,9 +318,20 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
     ACCUMULATOR_TYPE* d_zps = (ACCUMULATOR_TYPE*)(&d_zp);
 #endif

+    ACTIVATION_VEC_TYPE activated[TILE_B] = { };
 #if OUTER_OFM > 1
     uint input_offset_init = input_offset;
-    unroll_for (uint oi = 0; oi < OUTER_OFM; ++oi) {
+    uint weights_offset_init = weights_offset;
+    uint out_f_init = out_f;
+    __attribute__((opencl_unroll_hint(1)))
+    for (uint oi = 0; oi < OUTER_OFM; ++oi) {
+        input_offset = input_offset_init;
+        #ifdef SWIGLU_LENGTH
+        weights_offset = weights_offset_init + oi * (FILTER_IFM_NUM / (TILE_K_OFM / TILE_K_OFM_PACKED)) * SWIGLU_LENGTH;
+        out_f += SWIGLU_LENGTH * oi;
+        #else
+        out_f += TILE_OFM * SIMD * oi;
+        #endif
 #endif

 #if REALIGN_FP16_OFFSET
@@ -669,14 +699,28 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
 #endif // MAIN_LOOP_ELEMENTS_COUNT % (TILE_IFM * SIMD) != 0
     // =====================================================================================================================================
     // Post-processing: bias, activation, fused-ops
-    ACTIVATION_VEC_TYPE activated[TILE_B] = { };
-    for (uint bi = 0; bi < TILE_B; ++bi) {
+    unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+        #ifdef SWIGLU_LENGTH
+        if (oi == 0) {
+            // swish
+            activated[bi] = TO_ACTIVATION_VEC_TYPE(acc[bi]);
+            activated[bi] /= (ACCUMULATOR_VAL_ONE + native_exp(-(ACCUMULATOR_VAL_ONE * activated[bi])));
+        } else {
+            activated[bi] *= TO_ACTIVATION_VEC_TYPE(acc[bi]);
+        }
+        #else
         activated[bi] = TO_ACTIVATION_VEC_TYPE(acc[bi]);
+        #endif
 #if OUTER_OFM > 1
         acc[bi] = 0;
 #endif
     }

+#if OUTER_OFM > 1 && defined(SWIGLU_LENGTH)
+    }
+    out_f = out_f_init;
+#endif
+
 #if BIAS_TERM
     #if TILE_OUT_F_NUM % (OUTER_OFM * TILE_OFM * SIMD) == 0
         BIAS_VEC_TYPE bias = BIAS_BLOCK_READ(biases, out_f);
@@ -746,9 +790,7 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
             output_offset += TILE_OUT_B_PITCH - TILE_OFM * SIMD;
         }
     }
-#if OUTER_OFM > 1
-    out_f += TILE_OFM * SIMD;
-    input_offset = input_offset_init;
+#if OUTER_OFM > 1 && !defined(SWIGLU_LENGTH)
     }
 #endif
     // =====================================================================================================================================
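The kernel change above turns the existing OUTER_OFM loop into a two-pass gate/up computation. A scalar C++ model of what one output element goes through (a sketch, not the OpenCL code; weight pointers are illustrative):

```cpp
// Pass oi == 0 accumulates the gate projection and applies Swish in place
// (x / (1 + exp(-x)), matching the native_exp expression in the kernel);
// pass oi == 1 accumulates the up projection and multiplies into the stored
// activation, which is exactly the SwiGLU result written out afterwards.
#include <cmath>

float fused_swiglu_scalar(const float* x, const float* w_gate, const float* w_up, int k) {
    float acc = 0.f;
    for (int i = 0; i < k; ++i) acc += x[i] * w_gate[i];   // oi == 0
    float activated = acc / (1.f + std::exp(-acc));        // swish
    acc = 0.f;                                             // acc[bi] = 0 in the kernel
    for (int i = 0; i < k; ++i) acc += x[i] * w_up[i];     // oi == 1
    return activated * acc;
}
```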
@@ -816,8 +858,14 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
     // full dispatch pipeline.
     uint feature_mini_block = gid % DISPATCH_FSV;
     uint batch_mini_block = gid / DISPATCH_FSV % DISPATCH_BSV;
+    #ifdef SWIGLU_LENGTH
     uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
     uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    #else
+    uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, OUTER_OFM * TILE_OFM * SIMD) / DISPATCH_FSV);
+    uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, OUTER_OFM * TILE_OFM * SIMD) / DISPATCH_FSV);
+    #endif
+
     FILTER_VEC_TYPE wei = 0;

@@ -895,6 +943,22 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
     ACCUMULATOR_TYPE* d_zps = (ACCUMULATOR_TYPE*)(&d_zp);
 #endif

+    ACTIVATION_VEC_TYPE activated[TILE_B] = { };
+#if OUTER_OFM > 1
+    uint input_offset_init = input_offset;
+    uint weights_offset_init = weights_offset;
+    uint out_f_init = out_f;
+    __attribute__((opencl_unroll_hint(1)))
+    for (uint oi = 0; oi < OUTER_OFM; ++oi) {
+        input_offset = input_offset_init;
+        #ifdef SWIGLU_LENGTH
+        weights_offset = weights_offset_init + oi * (FILTER_IFM_NUM / (TILE_K_OFM / TILE_K_OFM_PACKED)) * SWIGLU_LENGTH;
+        out_f += SWIGLU_LENGTH * oi;
+        #else
+        out_f += TILE_OFM * SIMD * oi;
+        #endif
+#endif
+
     // =====================================================================================================================================
     // Main computation loop
     const uint iterations = MAIN_LOOP_ELEMENTS_COUNT / TILE_IFM_ELEMENTS_SIZE;  // TILE_IFM_ELEMENTS_SIZE : (TILE_IFM * SIMD)
@@ -1164,10 +1228,26 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
     // =====================================================================================================================================
     // Post-processing: bias, activation, fused-ops
-    ACTIVATION_VEC_TYPE activated[TILE_B] = { };
     for (uint bi = 0; bi < TILE_B; ++bi) {
+        #ifdef SWIGLU_LENGTH
+        if (oi == 0) {
+            activated[bi] = TO_ACTIVATION_VEC_TYPE(acc[bi]);
+            activated[bi] /= (ACCUMULATOR_VAL_ONE + native_exp(-(ACCUMULATOR_VAL_ONE * activated[bi])));
+        } else {
+            activated[bi] *= TO_ACTIVATION_VEC_TYPE(acc[bi]);
+        }
+        #else
         activated[bi] = TO_ACTIVATION_VEC_TYPE(acc[bi]);
+        #endif
+#if OUTER_OFM > 1
+        acc[bi] = 0;
+#endif
     }
+
+#if OUTER_OFM > 1 && defined(SWIGLU_LENGTH)
+    }
+    out_f = out_f_init;
+#endif

 #if BIAS_TERM
     #if TILE_OUT_F_NUM % (TILE_OFM * SIMD) == 0
@@ -1240,6 +1320,9 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
             output_offset += TILE_OUT_B_PITCH - TILE_OFM * SIMD;
         }
     }
+#if OUTER_OFM > 1 && !defined(SWIGLU_LENGTH)
+    }
+#endif
     // =====================================================================================================================================
 }
 #endif
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/fully_connected_gpu_bf_tiled_common.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/fully_connected_gpu_bf_tiled_common.cl
index ddffa87b202816..18d62f894f80f9 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/fully_connected_gpu_bf_tiled_common.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/fully_connected_gpu_bf_tiled_common.cl
@@ -25,7 +25,6 @@ inline void (FUNC_NAME)(
 ) {
     uint gid = (uint)get_group_id(0);
     uint sglid = (uint)get_sub_group_local_id();
-
     // Dispatch as bs_fs_bsv_fsv, where bsv = DISPATCH_BSV and fsv = DISPATCH_FSV.
     // This allows more fine grained control over dispatch order than using work-groups and
     // avoids requirement of threads being available for whole work-group.
@@ -33,10 +32,19 @@ inline void (FUNC_NAME)(
     // full dispatch pipeline.
     uint feature_mini_block = gid % DISPATCH_FSV;
     uint batch_mini_block = gid / DISPATCH_FSV % DISPATCH_BSV;
+    #ifdef SWIGLU_LENGTH
+    uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
+    #else
     uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, OUTER_OFM * TILE_OFM * SIMD) / DISPATCH_FSV);
     uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, OUTER_OFM * TILE_OFM * SIMD) / DISPATCH_FSV);
+    #endif

+    #ifdef SWIGLU_LENGTH
+    uint out_f = (feature_mega_block * DISPATCH_FSV + feature_mini_block) * (TILE_OFM * SIMD);
+    #else
     uint out_f = (feature_mega_block * DISPATCH_FSV + feature_mini_block) * (OUTER_OFM * TILE_OFM * SIMD);
+    #endif
     uint out_b = ((batch_mega_block * DISPATCH_BSV + batch_mini_block) * FORCED_TILE_B);

     ACCUMULATOR_VEC_TYPE acc[FORCED_TILE_B] = { };
@@ -90,9 +98,19 @@ inline void (FUNC_NAME)(
     ACCUMULATOR_TYPE* d_zps = (ACCUMULATOR_TYPE*)(&d_zp);
 #endif

+    ACTIVATION_VEC_TYPE activated[FORCED_TILE_B] = { };
 #if OUTER_OFM > 1
     uint input_offset_init = input_offset;
+    uint weights_offset_init = weights_offset;
+    uint out_f_init = out_f;
     unroll_for (uint oi = 0; oi < OUTER_OFM; ++oi) {
+        input_offset = input_offset_init;
+        #ifdef SWIGLU_LENGTH
+        weights_offset = weights_offset_init + oi * (FILTER_IFM_NUM / (TILE_K_OFM / TILE_K_OFM_PACKED)) * SWIGLU_LENGTH;
+        out_f += SWIGLU_LENGTH * oi;
+        #else
+        out_f += TILE_OFM * SIMD * oi;
+        #endif
 #endif

 #if REALIGN_FP16_OFFSET
@@ -297,14 +315,27 @@ inline void (FUNC_NAME)(
 #endif // MAIN_LOOP_ELEMENTS_COUNT % (TILE_IFM * SIMD) != 0
     // =====================================================================================================================================
     // Post-processing: bias, activation, fused-ops
-    ACTIVATION_VEC_TYPE activated[FORCED_TILE_B] = { };
     for (uint bi = 0; bi < FORCED_TILE_B; ++bi) {
+        #ifdef SWIGLU_LENGTH
+        if (oi == 0) {
+            activated[bi] = TO_ACTIVATION_VEC_TYPE(acc[bi]);
+            activated[bi] /= (ACCUMULATOR_VAL_ONE + native_exp(-(ACCUMULATOR_VAL_ONE * activated[bi])));
+        } else {
+            activated[bi] *= TO_ACTIVATION_VEC_TYPE(acc[bi]);
+        }
+        #else
         activated[bi] = TO_ACTIVATION_VEC_TYPE(acc[bi]);
+        #endif
 #if OUTER_OFM > 1
         acc[bi] = 0;
 #endif
     }

+#if OUTER_OFM > 1 && defined(SWIGLU_LENGTH)
+    }
+    out_f = out_f_init;
+#endif
+
 #if BIAS_TERM
     #if TILE_OUT_F_NUM % (OUTER_OFM * TILE_OFM * SIMD) == 0
         BIAS_VEC_TYPE bias = BIAS_BLOCK_READ(biases, out_f);
@@ -396,9 +427,7 @@ inline void (FUNC_NAME)(
             output_offset += TILE_OUT_B_PITCH - TILE_OFM * SIMD;
         }
     }
-#if OUTER_OFM > 1
-    out_f += TILE_OFM * SIMD;
-    input_offset = input_offset_init;
+#if OUTER_OFM > 1 && !defined(SWIGLU_LENGTH)
     }
 #endif
     // =====================================================================================================================================
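The weights_offset expression repeated in all three kernels is where the gate/up split materializes: the second pass skips over SWIGLU_LENGTH output rows, each holding FILTER_IFM_NUM packed elements divided by the int4 pack factor. A worked example with illustrative sizes (not taken from any specific model config):

```cpp
// For int4 weights two values share a byte, so TILE_K_OFM / TILE_K_OFM_PACKED == 2
// and pass oi == 1 starts SWIGLU_LENGTH * (K / 2) packed elements past the gate half.
#include <cstdio>

int main() {
    const unsigned filter_ifm_num = 4096;   // K, illustrative
    const unsigned swiglu_length  = 11008;  // split length N, illustrative
    const unsigned pack_factor    = 2;      // TILE_K_OFM / TILE_K_OFM_PACKED for int4
    for (unsigned oi = 0; oi < 2; ++oi) {
        const unsigned weights_offset = oi * (filter_ifm_num / pack_factor) * swiglu_length;
        printf("pass %u reads the %s half starting at packed element %u\n",
               oi, oi == 0 ? "gate" : "up", weights_offset);
    }
    return 0;
}
```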
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
index 02304512637783..fa601d80b10cf9 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
@@ -4,6 +4,7 @@

 #include "fully_connected_kernel_bf_tiled.h"
 #include "kernel_selector_utils.h"
+#include "swiglu/swiglu_kernel_base.h"
 #include <vector>
 #include <functional>
 #include "common_types.h"
@@ -163,7 +164,21 @@ static bool is_weight_small_kn(const fully_connected_params& params, size_t outp
     return output_f / 2 /*most frequently used tile_ofm*/ <= min_num_threads;
 }

+static bool is_swiglu_fused(const fully_connected_params& params) {
+    bool swiglu_fused = false;
+    if (!params.fused_ops.empty()) {
+        for (auto p : params.fused_ops) {
+            if (p.GetType() == kernel_selector::KernelType::SWIGLU)
+                swiglu_fused = true;
+        }
+    }
+    if (swiglu_fused)
+        OPENVINO_ASSERT(params.fused_ops.size() == 1);
+    return swiglu_fused;
+}
+
 static bool is_suitable_outer_ofm(const fully_connected_params& params, size_t output_f) {
+    if (is_swiglu_fused(params))
+        return true;
     size_t min_num_threads = params.engineInfo.computeUnitsCount * simd;
     return (params.weights.OFM().v > params.weights.IFM().v * 6 &&
             output_f / 8 /* tile_ofm=4 and outer_ofm=2 */ > min_num_threads * 1.5);
@@ -406,6 +421,8 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
     while (max_tile_ofm * 2 * simd <= output_f && max_tile_ofm < 4)
         max_tile_ofm *= 2;

+    bool swiglu_fused = is_swiglu_fused(params);
+
     if (params.weights.GetDType() == WeightsType::UINT4 || params.weights.GetDType() == WeightsType::INT4 ||
         (is_weight_dyn_quantizable(params) && should_dynamic_quantize(params))) {
         // Only 4bit weight type is fully optimized to use SLM. In default kernel, SLM is not applied to 8bit weight.
@@ -426,30 +443,39 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
             if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16) {
                 return selector.Default(tune_params(1, 1, 4, 4, 1, 1, 1, EXE_MODE_DEFAULT));
             } else if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv64_isv2) {
-                selector.Case(tune_params(1, 4, 4, 2, 2, 1, 1, EXE_MODE_DEFAULT))
-                        .Case(tune_params(1, 4, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
+                // Here: static shape with batch size 1
+                if (swiglu_fused) {
+                    return selector.Default(tune_params(1, 4, 4, 2, 2, 1, 1, EXE_MODE_DEFAULT));
+                } else {
+                    selector.Case(tune_params(1, 4, 4, 2, 2, 1, 1, EXE_MODE_DEFAULT))
+                            .Case(tune_params(1, 4, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
+                }
             } else {
-                return selector.Default(tune_params(1, 2, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
+                if (swiglu_fused) {
+                    return selector.Default(tune_params(1, 2, 4, 2, 2, 1, 1, EXE_MODE_DEFAULT));
+                } else {
+                    return selector.Default(tune_params(1, 2, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
+                }
             }
         }
     } else {
         // Try to use SLM kernels if possible
+        unsigned int forced_outer_ofm = swiglu_fused ? 2 : 1;
         if (preferred_kernel_type != KernelType::DEFAULT) {
             if (params.is_shape_agnostic && !should_dynamic_quantize(params)) {
-                selector.Case(tune_params(16, 2, 2, 4, 1, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM))
-                        .Case(tune_params(16, 2, 1, 4, 1, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM));
+                selector.Case(tune_params(16, 2, 2, 4, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM))
+                        .Case(tune_params(16, 2, 1, 4, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM));
             }
-
-            selector.Case(tune_params(8, 2, 2, 4, 1, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM))
-                    .Case(tune_params(8, 2, 1, 4, 1, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM));
+            selector.Case(tune_params(8, 2, 2, 4, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM))
+                    .Case(tune_params(8, 2, 1, 4, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM));
         }

         if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16)
-            return selector.Default(tune_params(8, 1, 1, 4, 1, 1, 1, EXE_MODE_DEFAULT));
+            return selector.Default(tune_params(8, 1, 1, 4, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT));
         else if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv64_isv2)
-            return selector.Default(tune_params(8, 4, 1, 2, 1, 1, 1, EXE_MODE_DEFAULT));
+            return selector.Default(tune_params(8, 4, 1, 2, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT));
         else
-            return selector.Default(tune_params(8, 2, 1, 4, 1, 1, 1, EXE_MODE_DEFAULT));
+            return selector.Default(tune_params(8, 2, 1, 4, forced_outer_ofm, 1, 1, EXE_MODE_DEFAULT));
     }
 } else if (params.compressed && params.engineInfo.supports_immad) {
     return selector.Default(tune_params(1, 1, 1, 4, 1, 1, 1, EXE_MODE_DEFAULT));
@@ -526,8 +552,12 @@ FullyConnected_bf_tiled::SetDefault(const fully_connected_params& params, int au
     kernel_type = kernel_number == 0 ? KernelType::DEFAULT : KernelType::SLM;

     auto tparams = GetAutoTuneParams(params, kernel_type, autoTuneIndex);
+    std::pair<size_t, size_t> threads;
+    if (is_swiglu_fused(params))
+        threads = get_output_aligned_bf_size(params, true, tparams.tile_b, tparams.tile_ofm * simd);
+    else
+        threads = get_output_aligned_bf_size(params, true, tparams.tile_b, tparams.tile_ofm * tparams.outer_ofm * simd);

-    auto threads = get_output_aligned_bf_size(params, true, tparams.tile_b, tparams.tile_ofm * tparams.outer_ofm * simd);
     auto batch_threads = threads.first;
     auto feature_threads = threads.second;

@@ -575,6 +605,11 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
     size_t tile_k_ofm_packed = tile_k_ofm;
     size_t quantize_grp_size = get_dynamic_quantize_group_size(params);

+    if (is_swiglu_fused(params)) {
+        auto split_length = params.fused_ops[0].GetOpParams<swiglu_fuse_params>()->split_length;
+        jit.AddConstant(MakeJitConstant("SWIGLU_LENGTH", split_length));
+    }
+
     bool add_decompress_scale_post_op = false;
     WeightsType weights_dt = params.weights.GetDType();
     if (weights_dt == WeightsType::UINT4 || weights_dt == WeightsType::INT4) {
@@ -723,7 +758,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
         jit.AddConstant(MakeJitConstant("BATCH_SIZE", "(OUTPUT_BATCH_NUM)"));
     }

-    if (!params.fused_ops.empty()) {
+    if (!params.fused_ops.empty() && !is_swiglu_fused(params)) {
         std::vector<std::string> idx_order_scalar = { "(out_b + bi)", "(out_f + sglid)", "0", "0" };
         std::vector<std::string> idx_order_vec = { "(out_b + bi)", "(out_f + sglid + fi * SIMD)", "0", "0" };
         if (params.outputs[0].GetLayout() == DataLayout::bfyx) {
@@ -828,7 +863,7 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
     auto output_f = get_output_aligned_bf_size(fc_params, false).second;

     WeightsLayout weights_layout = WeightsLayout::os_iyx_osv16;
-    if (fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
+    if (!is_swiglu_fused(fc_params) && fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
         && (fc_params.weights.GetLayout() == WeightsLayout::oiyx || fc_params.weights.GetLayout() == WeightsLayout::os_is_yx_osv64_isv2)
         && (fc_params.weights.GetDType() == WeightsType::INT4 || fc_params.weights.GetDType() == WeightsType::UINT4)
         && is_weight_horizontal(fc_params, output_f)) {
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.h
index cbbf52adf344ce..1093c7377bf76f 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.h
@@ -76,7 +76,8 @@ class FullyConnected_bf_tiled : public FullyConnectedKernelBase {
     std::vector<FusedOpType> GetSupportedFusedOps() const override {
         return { FusedOpType::ACTIVATION,
                  FusedOpType::ELTWISE,
-                 FusedOpType::QUANTIZE };
+                 FusedOpType::QUANTIZE,
+                 FusedOpType::SWIGLU };
     }
     JitConstants GetJitConstants(const fully_connected_params& params, const DispatchData& dispatchData) const override;
     bool Validate(const Params& params) const override;
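The SetDefault change above is the dispatch-side counterpart of the kernel change: when swiglu is fused, one work-item produces both halves for its TILE_OFM * SIMD features, so the feature-thread count must not be divided by outer_ofm. A small worked example with illustrative numbers:

```cpp
// ceil-div of output features by the per-work-item feature tile; with swiglu
// fused the outer_ofm factor is dropped from the divisor even though OUTER_OFM == 2.
#include <cstdio>

int main() {
    const size_t out_f = 11008, simd = 16, tile_ofm = 2, outer_ofm = 2;
    const size_t threads_swiglu  = (out_f + tile_ofm * simd - 1) / (tile_ofm * simd);
    const size_t threads_default = (out_f + tile_ofm * outer_ofm * simd - 1) / (tile_ofm * outer_ofm * simd);
    printf("feature threads: %zu (swiglu fused) vs %zu (plain OUTER_OFM split)\n",
           threads_swiglu, threads_default);  // 344 vs 172
    return 0;
}
```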
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h
index 2f5c046690f78d..62db07e540e429 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h
@@ -21,6 +21,15 @@ struct swiglu_params : public base_params {
     int32_t split_to_glu_idx;
 };

+struct swiglu_fuse_params : fuse_params {
+    explicit swiglu_fuse_params(int32_t axis, size_t split_lengths)
+        : fuse_params(KernelType::SWIGLU),
+          axis(axis),
+          split_length(split_lengths) {}
+    int32_t axis;
+    size_t split_length;
+};
+
 class SwiGLUKernelBase : public KernelBaseOpenCL {
 public:
     using KernelBaseOpenCL::KernelBaseOpenCL;
diff --git a/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp b/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp
index 5df2cafd41a41f..23b44dcc1a4677 100644
--- a/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp
+++ b/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp
@@ -21,7 +21,7 @@ static void CreateGLUOp(ProgramBuilder& p, const std::shared_ptr<ov::op::internal::GLU>& op) {
     if (p.use_new_shape_infer()) {
         auto prim = cldnn::swiglu(primitive_name,
                                   inputs[0],
-                                  op->get_axis(),
+                                  (op->get_axis() < 0 ? op->get_input_partial_shape(0).size() + op->get_axis() : op->get_axis()),
                                   op->get_split_lengths(),
                                   op->get_glu_type(),
                                   op->get_split_to_glu_idx(),
@@ -31,7 +31,7 @@ static void CreateGLUOp(ProgramBuilder& p, const std::shared_ptr<ov::op::internal::GLU>& op) {
     } else {
         auto prim = cldnn::swiglu(primitive_name,
                                   inputs[0],
-                                  op->get_axis(),
+                                  (op->get_axis() < 0 ? op->get_input_partial_shape(0).size() + op->get_axis() : op->get_axis()),
                                   op->get_split_lengths(),
                                   op->get_glu_type(),
                                   op->get_split_to_glu_idx(),
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp
index fcb339531c1883..327de1424c34c9 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp
@@ -18,16 +18,25 @@
 namespace ov {
 namespace intel_gpu {

-FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
+FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion(bool fuse_mlp_swiglu) {
     using namespace ov::pass::pattern;

-    auto is_target_pattern = [](const Output<Node>& output) {
+    GPU_DEBUG_GET_INSTANCE(debug_config);
+    // Three FCs connected to the same input
+    size_t min_num_fcs_to_fuse = 3;
+    // Note:
+    // For cldnn, the two FCs of an MLP block are fused by horizontal fc fusion first, and swiglu is then fused at prepare_primitive_fusing,
+    // i.e., eltwise((fc + swish), fc) => fused_fc + swiglu => fused_fc_swiglu
+    // OneDNN gemms are to be handled in a different way (TBD)
+    if (fuse_mlp_swiglu)
+        min_num_fcs_to_fuse = 2;
+    auto is_target_pattern = [min_num_fcs_to_fuse](const Output<Node>& output) {
+        const int max_num_fcs_to_fuse = 3;
         // Currently this pass targets only compressed FCs (QKV) on dynamic generative models
         // inputs: input, weight, bias, scale, [zp]
         // Bias/scale/zp are constant or none
         // if it is not constant, the only allowed cases are Constant => convert
         // All FCs have same # of valid inputs (e.g., if one of the fc has zp, all fcs have zp)
-
         auto is_constant = [](const std::shared_ptr<ov::Node> node) {
             if (std::dynamic_pointer_cast<ov::op::v0::Constant>(node))
                 return true;
@@ -40,9 +49,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
         auto is_placeholder = [](const std::shared_ptr<ov::Node> node) {
             return std::dynamic_pointer_cast<op::Placeholder>(node);
         };
-        // Three FCs connected to the same input
-        const int min_num_fcs_to_fuse = 3;
-        const int max_num_fcs_to_fuse = 3;
+
         const auto& fc = std::dynamic_pointer_cast<op::FullyConnectedCompressed>(output.get_node_shared_ptr());
         const auto& input = fc->get_input_node_shared_ptr(0);
         if (!fc->get_input_partial_shape(0).is_dynamic())
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.hpp
index b6a852354bad8d..67abaa3df54357 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.hpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.hpp
@@ -12,7 +12,7 @@ namespace intel_gpu {
 class FullyConnectedHorizontalFusion: public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("FullyConnectedHorizontalFusion", "0");
-    FullyConnectedHorizontalFusion();
+    FullyConnectedHorizontalFusion(bool fuse_mlp_swiglu = false);
 };

 }  // namespace intel_gpu
diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
index fcb88560944854..e47ccbb09a9c43 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -913,12 +913,18 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         manager.register_pass<ov::intel_gpu::ConvertFullyConnectedToFullyConnectedCompressed>();

         bool disable_horizontal_fc_fusion = false;
+        bool disable_fc_swiglu_fusion = false;
         GPU_DEBUG_GET_INSTANCE(debug_config);
         GPU_DEBUG_IF(debug_config->disable_horizontal_fc_fusion == 1)
            disable_horizontal_fc_fusion = true;
-
+        GPU_DEBUG_IF(debug_config->disable_fc_swiglu_fusion == 1)
+            disable_fc_swiglu_fusion = true;
+        // MLP swiglu fusion is only supported for cldnn on high-performance GPUs
+        bool fuse_mlp_swiglu = !device_info.supports_immad &&
+                               device_info.execution_units_count >= 128 &&
+                               !disable_fc_swiglu_fusion;
         if (!disable_horizontal_fc_fusion)
-            manager.register_pass<FullyConnectedHorizontalFusion>();
+            manager.register_pass<FullyConnectedHorizontalFusion>(fuse_mlp_swiglu);

         // ZP should not be folded for FC. But still, ZP should be folded for Gather.
         // Therefore, run MarkDequantizationSubgraph again to fold ZP constant.
diff --git a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
index 4a68355e1bc8ba..65ca31f16c720c 100644
--- a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
+++ b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
@@ -191,6 +191,7 @@ static void print_help_messages() {
     message_list.emplace_back("OV_GPU_DynamicQuantizeGroupSize", "Specify a group size of dynamic quantization to enable "
                               "dynamic quantization for Fully-connected primitive.");
     message_list.emplace_back("OV_GPU_DisableHorizontalFCFusion", "Disable horizontal fc fusion");
+    message_list.emplace_back("OV_GPU_DisableFCSwigluFusion", "Disable fc + swiglu fusion");
     message_list.emplace_back("OV_GPU_DumpIteration", "Dump n-th execution of network, separated by space.");
     message_list.emplace_back("OV_GPU_MemPreallocationOptions", "Controls buffer pre-allocation feature. Expects 4 values separated by space in "
                               "the following order: number of iterations for pre-allocation(int), max size of single iteration in bytes(int), "
@@ -259,7 +260,8 @@ debug_configuration::debug_configuration()
     , use_usm_host(0)
     , use_kv_cache_compression(-1)
     , dynamic_quantize_group_size(DYNAMIC_QUANTIZE_GROUP_SIZE_NOT_SET)
-    , disable_horizontal_fc_fusion(0) {
+    , disable_horizontal_fc_fusion(0)
+    , disable_fc_swiglu_fusion(0) {
 #ifdef GPU_DEBUG_CONFIG
     get_gpu_debug_env_var("Help", help);
     get_common_debug_env_var("Verbose", verbose);
@@ -314,6 +316,7 @@ debug_configuration::debug_configuration()
     get_gpu_debug_env_var("KVCacheCompression", use_kv_cache_compression);
     get_gpu_debug_env_var("DynamicQuantizeGroupSize", dynamic_quantize_group_size);
     get_gpu_debug_env_var("DisableHorizontalFCFusion", disable_horizontal_fc_fusion);
+    get_gpu_debug_env_var("DisableFCSwigluFusion", disable_fc_swiglu_fusion);
     std::string dump_iteration_str;
     get_gpu_debug_env_var("DumpIteration", dump_iteration_str);
     std::string mem_preallocation_params_str;
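Note: the new knob follows the existing OV_GPU_* debug pattern. In builds compiled with GPU_DEBUG_CONFIG, setting the environment variable OV_GPU_DisableFCSwigluFusion=1 disables the fusion at both places it is applied, the transformation pipeline (via fuse_mlp_swiglu) and the graph-level prepare_primitive_fusing pass, which is useful for isolating accuracy or performance regressions to this optimization.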
diff --git a/src/plugins/intel_gpu/tests/unit/fusions/fully_connected_fusion_test.cpp b/src/plugins/intel_gpu/tests/unit/fusions/fully_connected_fusion_test.cpp
index 5e9b5134fb3802..09e164742f3fd9 100644
--- a/src/plugins/intel_gpu/tests/unit/fusions/fully_connected_fusion_test.cpp
+++ b/src/plugins/intel_gpu/tests/unit/fusions/fully_connected_fusion_test.cpp
@@ -9,6 +9,7 @@
 #include <intel_gpu/primitives/input_layout.hpp>
 #include <intel_gpu/primitives/quantize.hpp>
 #include <intel_gpu/primitives/eltwise.hpp>
+#include <intel_gpu/primitives/swiglu.hpp>
 #include <intel_gpu/primitives/fully_connected.hpp>
 #include <intel_gpu/primitives/data.hpp>

@@ -73,7 +74,7 @@ class FullyConnectedFusingTest : public ::BaseFusingTest<fully_connected_test_params> {
+    fully_connected_test_params{ CASE_FC_FP16_INT4_SWIGLU_1, 2, 3 },
+    fully_connected_test_params{ CASE_FC_FP16_INT4_SWIGLU_2, 2, 3 },
+    fully_connected_test_params{ CASE_FC_FP16_INT4_SWIGLU_3, 2, 3 },
+}));
+
 class fc_imad_int8_eltwise_add_ocl_dynamic : public FullyConnectedFusingTest {
 public:
     void run_test() {
diff --git a/src/plugins/intel_gpu/tests/unit/fusions/fusion_test_common.hpp b/src/plugins/intel_gpu/tests/unit/fusions/fusion_test_common.hpp
index eb0f63c651e50d..c469925083b775 100644
--- a/src/plugins/intel_gpu/tests/unit/fusions/fusion_test_common.hpp
+++ b/src/plugins/intel_gpu/tests/unit/fusions/fusion_test_common.hpp
@@ -147,6 +147,12 @@ class BaseFusingTest : public ::testing::TestWithParam<T> {
         } else if (l.data_type == data_types::i8) {
             VF<int8_t> rnd_vec(s.count(), static_cast<int8_t>(fill_value));
             set_values(prim, rnd_vec);
+        } else if (l.data_type == data_types::u4) {
+            VF<uint8_t> rnd_vec(s.count() / 2, static_cast<uint8_t>(fill_value));
+            set_values(prim, rnd_vec);
+        } else if (l.data_type == data_types::i4) {
+            VF<int8_t> rnd_vec(s.count() / 2, static_cast<int8_t>(fill_value));
+            set_values(prim, rnd_vec);
         } else {
             throw std::runtime_error("get_mem: Unsupported precision");
         }
@@ -186,6 +192,12 @@ class BaseFusingTest : public ::testing::TestWithParam<T> {
         } else if (l.data_type == data_types::u8) {
             VF<uint8_t> rnd_vec = rg.generate_random_1d<uint8_t>(s.count(), min, max);
             set_values(prim, rnd_vec);
+        } else if (l.data_type == data_types::i4) {
+            VF<int8_t> rnd_vec = rg.generate_random_1d<int8_t>(s.count() / 2, min, max);
+            set_values(prim, rnd_vec);
+        } else if (l.data_type == data_types::u4) {
+            VF<uint8_t> rnd_vec = rg.generate_random_1d<uint8_t>(s.count() / 2, min, max);
+            set_values(prim, rnd_vec);
         }
         return prim;
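Two remarks on the test changes. First, the definitions of the CASE_FC_FP16_INT4_SWIGLU_* macros and the surrounding test-suite declaration are not visible in this excerpt of the diff, only the new INSTANTIATE entries are. Second, the s.count() / 2 sizing in the helpers reflects that 4-bit tensors are backed by bytes holding two values each; a minimal sketch of the packing convention the helpers assume (low nibble first, illustrative rather than the library's internal code):

```cpp
// Two 4-bit values per byte: an s.count()-element u4 tensor needs
// s.count() / 2 backing bytes, which is what the test helpers allocate.
#include <cstdint>
#include <vector>

std::vector<uint8_t> pack_u4(const std::vector<uint8_t>& vals) {  // each value in [0, 15]
    std::vector<uint8_t> packed(vals.size() / 2);
    for (size_t i = 0; i < packed.size(); ++i)
        packed[i] = static_cast<uint8_t>((vals[2 * i] & 0xF) | ((vals[2 * i + 1] & 0xF) << 4));
    return packed;
}
```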