Skip to content

Commit

Permalink
add chanel first format support for rand=3 Deconv to use amx fp16 kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
liubo-intel committed Oct 23, 2024
1 parent a4b7f35 commit bdf3b0a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 20 deletions.
22 changes: 8 additions & 14 deletions src/plugins/intel_cpu/src/nodes/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,12 @@ std::vector<memory::format_tag> Deconvolution::getAvailableFormatsForDims(const
else if (dims.getRank() == 2)
return {memory::format_tag::nc};
else if (dims.getRank() == 3)
return {memory::format_tag::tnc, memory::format_tag::ntc,
memory::format_tag::ncw, memory::format_tag::nCw8c, memory::format_tag::nCw16c };
return {memory::format_tag::tnc,
memory::format_tag::ntc,
memory::format_tag::nwc,
memory::format_tag::ncw,
memory::format_tag::nCw8c,
memory::format_tag::nCw16c};
else if (dims.getRank() == 4)
return {memory::format_tag::nchw, memory::format_tag::nChw8c,
memory::format_tag::nChw16c, memory::format_tag::nhwc };
Expand Down Expand Up @@ -461,18 +465,8 @@ void Deconvolution::getSupportedDescriptors() {
outputDataType = DnnlExtensionUtils::ElementTypeToDataType(outPrecision);
if (inputDataType == memory::data_type::bf16 || outputDataType == memory::data_type::bf16)
inputDataType = outputDataType = memory::data_type::bf16;

if (inputDataType == memory::data_type::f16 || outputDataType == memory::data_type::f16) {
// TODO: remove this limitation after adding support for f16 in oneDNN MFDNN-12580
if (std::any_of(deconvAttrs.stride.begin(), deconvAttrs.stride.end(), [](ptrdiff_t stride) {
return stride != 1;
})) {
inputDataType = outputDataType = memory::data_type::f32;
} else {
inputDataType = outputDataType = memory::data_type::f16;
}
}

if (inputDataType == memory::data_type::f16 || outputDataType == memory::data_type::f16)
inputDataType = outputDataType = memory::data_type::f16;
if (!fusedWith.empty()) {
outputDataType = DnnlExtensionUtils::ElementTypeToDataType(fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0));
}
Expand Down
6 changes: 0 additions & 6 deletions src/plugins/intel_cpu/src/nodes/executors/acl/acl_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,6 @@ bool ACLConvertExecutorBuilder::isSupported(const ConvertParams& convertParams,
DEBUG_LOG("NECopy does not support source precision: ", convertParams.srcPrc.to_string());
return false;
}
auto srcDataLayout = getAclDataLayoutByMemoryDesc(srcDesc);
auto dstDataLayout = getAclDataLayoutByMemoryDesc(dstDesc);
if (srcDataLayout == DataLayout::UNKNOWN || dstDataLayout == DataLayout::UNKNOWN) {
DEBUG_LOG("NECopy does not support source or destination layout");
return false;
}
if ((convertParams.srcPrc == ov::element::i8 && !one_of(convertParams.dstPrc,
ov::element::i16,
ov::element::i32,
Expand Down

0 comments on commit bdf3b0a

Please sign in to comment.