diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake
index c01195c99e28d..6c5c4b21f5c58 100644
--- a/cmake/external/abseil-cpp.cmake
+++ b/cmake/external/abseil-cpp.cmake
@@ -27,14 +27,18 @@ FetchContent_Declare(
URL ${DEP_URL_abseil_cpp}
URL_HASH SHA1=${DEP_SHA1_abseil_cpp}
PATCH_COMMAND ${ABSL_PATCH_COMMAND}
- FIND_PACKAGE_ARGS NAMES absl
+ FIND_PACKAGE_ARGS 20240116 NAMES absl
)
onnxruntime_fetchcontent_makeavailable(abseil_cpp)
FetchContent_GetProperties(abseil_cpp)
set(ABSEIL_SOURCE_DIR ${abseil_cpp_SOURCE_DIR})
+# abseil_cpp_SOURCE_DIR is non-empty if we build it from source
message(STATUS "Abseil source dir:" ${ABSEIL_SOURCE_DIR})
-
+# abseil_cpp_VERSION is non-empty if we find a preinstalled ABSL
+if(abseil_cpp_VERSION)
+ message(STATUS "Abseil version:" ${abseil_cpp_VERSION})
+endif()
if (GDK_PLATFORM)
# Abseil considers any partition that is NOT in the WINAPI_PARTITION_APP a viable platform
# for Win32 symbolize code (which depends on dbghelp.lib); this logic should really be flipped
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8092c26da651a..67bfe48327e14 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -421,7 +421,7 @@ Do not modify directly.*
|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
+|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
|Unique|*in* X:**T**<br> *out* Y:**T**<br> *out* indices:**tensor(int64)**<br> *out* inverse_indices:**tensor(int64)**<br> *out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br> or<br> *in* data:**T**<br> *out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/common/logging/isink.h b/include/onnxruntime/core/common/logging/isink.h
index a67777d4ccc8b..fd011e71611fc 100644
--- a/include/onnxruntime/core/common/logging/isink.h
+++ b/include/onnxruntime/core/common/logging/isink.h
@@ -6,12 +6,15 @@
#include <string>
#include "core/common/logging/logging.h"
+#include "core/common/logging/sink_types.h"
namespace onnxruntime {
namespace logging {
class ISink {
public:
- ISink() = default;
+ explicit ISink(SinkType type = SinkType::BaseSink) : type_(type) {}
+
+ SinkType GetType() const { return type_; }
/**
Sends the message to the sink.
@@ -32,6 +35,8 @@ class ISink {
virtual ~ISink() = default;
private:
+ SinkType type_;
+
// Make Code Analysis happy by disabling all for now. Enable as needed.
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink);
diff --git a/include/onnxruntime/core/common/logging/logging.h b/include/onnxruntime/core/common/logging/logging.h
index f62053a5e44ab..55b5c25d1a222 100644
--- a/include/onnxruntime/core/common/logging/logging.h
+++ b/include/onnxruntime/core/common/logging/logging.h
@@ -14,10 +14,10 @@
#include "core/common/common.h"
#include "core/common/profiler_common.h"
#include "core/common/logging/capture.h"
-#include "core/common/logging/severity.h"
-
#include "core/common/logging/macros.h"
-
+#include "core/common/logging/severity.h"
+#include "core/common/logging/sink_types.h"
+#include "core/platform/ort_mutex.h"
#include "date/date.h"
/*
@@ -167,6 +167,23 @@ class LoggingManager final {
*/
static bool HasDefaultLogger() { return nullptr != s_default_logger_; }
+ /**
+ Gets the default instance of the LoggingManager.
+ */
+ static LoggingManager* GetDefaultInstance();
+
+ /**
+  Removes a sink of the specified type if one is present.
+ */
+ void RemoveSink(SinkType sinkType);
+
+ /**
+  Adds a sink to the current sink, creating a CompositeSink if necessary.
+  Sink types must be unique.
+ @param severity The severity level for the new Sink
+ */
+  bool AddSinkOfType(SinkType sinkType, std::function<std::unique_ptr<ISink>()> sinkFactory, logging::Severity severity);
+
/**
Change the minimum severity level for log messages to be output by the default logger.
@param severity The severity.
@@ -214,7 +231,10 @@ class LoggingManager final {
void CreateDefaultLogger(const std::string& logger_id);
std::unique_ptr<ISink> sink_;
- const Severity default_min_severity_;
+#ifdef _WIN32
+ mutable OrtMutex sink_mutex_;
+#endif
+ Severity default_min_severity_;
const bool default_filter_user_data_;
const int default_max_vlog_level_;
bool owns_default_logger_;
@@ -362,8 +382,8 @@ unsigned int GetProcessId();
/**
If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then adds to the existing logger.
*/
-std::unique_ptr<ISink> EnhanceLoggerWithEtw(std::unique_ptr<ISink> existingLogger, logging::Severity originalSeverity,
-                                            logging::Severity etwSeverity);
+std::unique_ptr<ISink> EnhanceSinkWithEtw(std::unique_ptr<ISink> existingSink, logging::Severity originalSeverity,
+                                          logging::Severity etwSeverity);
/**
If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then can override the logging level.
diff --git a/include/onnxruntime/core/common/logging/sink_types.h b/include/onnxruntime/core/common/logging/sink_types.h
new file mode 100644
index 0000000000000..a99b0fca58d9d
--- /dev/null
+++ b/include/onnxruntime/core/common/logging/sink_types.h
@@ -0,0 +1,11 @@
+#pragma once
+
+namespace onnxruntime {
+namespace logging {
+enum class SinkType {
+ BaseSink,
+ CompositeSink,
+ EtwSink
+};
+} // namespace logging
+} // namespace onnxruntime
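The new `SinkType` tag is what the composite-sink bookkeeping later in this patch keys on: it identifies a sink's category with a plain enum comparison instead of `dynamic_cast`/RTTI. A minimal standalone sketch of that lookup pattern (illustrative only; this `ISink` is a stand-in for the class in `isink.h` above):

```cpp
#include <algorithm>
#include <memory>
#include <vector>

enum class SinkType { BaseSink, CompositeSink, EtwSink };

struct ISink {
  explicit ISink(SinkType type = SinkType::BaseSink) : type_(type) {}
  virtual ~ISink() = default;
  SinkType GetType() const { return type_; }

 private:
  SinkType type_;
};

// A composite can answer "is a sink of this type already registered?"
// without RTTI, which is what AddSinkOfType relies on to keep types unique.
bool HasType(const std::vector<std::unique_ptr<ISink>>& sinks, SinkType type) {
  return std::any_of(sinks.begin(), sinks.end(),
                     [type](const std::unique_ptr<ISink>& s) { return s->GetType() == type; });
}
```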
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 3af4942c2e4aa..919b005ec4c21 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -74,6 +74,7 @@ Do not modify directly.*
| Not | ai.onnx(1+) | |
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
| Pow | ai.onnx(7-11,12,13-14,15+) | |
+| QuickGelu | com.microsoft(1+) | |
| Range | ai.onnx(11+) | |
| Reciprocal | ai.onnx(6-12,13+) | |
| ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | |
diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md
index 1df40b71a00fa..966c93a85ae2a 100644
--- a/js/web/docs/webnn-operators.md
+++ b/js/web/docs/webnn-operators.md
@@ -19,7 +19,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✗ | ✓ | Only supports 'training_mode' value is 0, one output |
| Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✗ | ✓ | |
| Ceil | ai.onnx(7-12, 13+) | ceil | ✓ | ✓ | |
-| Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | |
+| Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) |
| Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | |
| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU requires the 'W' (weight) input to be a constant |
| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✗ | Only supports 3-D or 4-D input and 'W' (weight). |
@@ -50,7 +50,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✗ | ✓ | |
| Log | ai.onnx(7-12, 13+) | log | ✗ | ✓ | |
| LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
-| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | WebNN CPU doesn't support broadcasting for MatMul |
+| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
| Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |
| MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output |
| Min | ai.onnx(7, 8-11, 12, 13+) | min | ✓ | ✓ | |
@@ -73,7 +73,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✗ | ✓ | Input 'axes' if present should be a constant |
| Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | |
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
-| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, WebNN CPU backend only supports 'linear' mode, WebNN GPU backend only supports 'linear' and 'nearest' modes |
+| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes |
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | |
| Softplus | ai.onnx(7+) | softplus | ✗ | ✓ | |
@@ -81,7 +81,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Sin | ai.onnx(7+) | sin | ✗ | ✓ | |
| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 |
| Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | Only supports input rank >= 2 |
-| Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant, WebNN CPU backend only supports up to 4 outputs |
+| Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant |
| Sqrt | ai.onnx(7-12, 13+) | sqrt | ✓ | ✓ | |
| Squeeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | Input 'axes' if present should be a constant |
| Sub | ai.onnx(7-12, 13, 14+) | sub | ✓ | ✓ | |
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 2d2f345d0c273..ce5b4455fde60 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -107,6 +107,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Not', [unaryOps.not]],
['Pad', [pad]],
['Pow', [binaryOps.pow]],
+ ['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]],
['Range', [range]],
['Reciprocal', [unaryOps.reciprocal]],
['ReduceMin', [reduceMin]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 5f105c745739e..12ba2a10cdf9f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -314,3 +314,31 @@ export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttrib
export const log = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Log', 'log'));
};
+
+export const quickGeluImpl = (varType: string, alpha: number) => `
+const alpha = vec4<${varType}>(${alpha});
+const one = ${varType}(1.0);
+const zero = ${varType}(0.0);
+
+fn quick_gelu_impl(x: vec4<${varType}>) -> vec4<${varType}> {
+  let v = x * alpha;
+ var x1 : vec4<${varType}>;
+ for (var i = 0; i < 4; i = i + 1) {
+ if (v[i] >= zero) {
+ x1[i] = one / (one + exp(-v[i]));
+ } else {
+ x1[i] = one - one / (one + exp(v[i]));
+ }
+ }
+ return x * x1;
+}
+`;
+
+export const quickGeluExpression = (x: string) => `quick_gelu_impl(${x})`;
+
+export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
+ const dType = tensorTypeToWsglValueType(context.inputs[0].dataType);
+ context.compute(createElementwiseProgramInfo(
+ context.inputs[0], 'QuickGelu', quickGeluExpression, quickGeluImpl(dType, attributes.alpha), attributes.cacheKey,
+ context.inputs[0].dataType));
+};
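The WGSL above computes QuickGelu(x) = x * sigmoid(alpha * x), branching so that `exp` is only ever called with a non-positive argument. A scalar C++ sketch of the same math for reference (the 1.702 below is the QuickGelu default alpha from the contrib-op definition, assumed here to be what the test data in quick-gelu.jsonc was generated with):

```cpp
#include <cmath>
#include <cstdio>

// Numerically stable logistic sigmoid, mirroring the branch in quick_gelu_impl:
// for v < 0 we evaluate exp(v) instead of exp(-v), so exp never overflows.
float StableSigmoid(float v) {
  return v >= 0.0f ? 1.0f / (1.0f + std::exp(-v))
                   : 1.0f - 1.0f / (1.0f + std::exp(v));
}

float QuickGelu(float x, float alpha) { return x * StableSigmoid(alpha * x); }

int main() {
  // With alpha = 1.702, QuickGelu(0.1f) ~= 0.0542447, matching the first
  // expected output in the test file below.
  std::printf("%f\n", QuickGelu(0.1f, 1.702f));
  return 0;
}
```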
diff --git a/js/web/test/data/ops/quick-gelu.jsonc b/js/web/test/data/ops/quick-gelu.jsonc
new file mode 100644
index 0000000000000..a6e618fe34796
--- /dev/null
+++ b/js/web/test/data/ops/quick-gelu.jsonc
@@ -0,0 +1,46 @@
+[
+ {
+ "name": "QuickGelu test",
+ "operator": "QuickGelu",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "cases": [
+ {
+ "name": "[2x4]",
+ "inputs": [
+ {
+ "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, -0.8],
+ "dims": [2, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.441123, 0.53689, 0.636815],
+ "dims": [2, 4],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "[3x5]",
+ "inputs": [
+ {
+ "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, -1.5],
+ "dims": [3, 5],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.845795, 1.9356, 2.98192, 3.99558, 4.99899, 0.953383,
+ 1.0622, 1.17178, 1.2817, 1.39166
+ ],
+ "dims": [3, 5],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ }
+]
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 34f57c1655cc2..8ae7b4589d677 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -69,9 +69,8 @@ class AttentionCPUBase : public AttentionBase {
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
  const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data<int32_t>() : nullptr;
-  gsl::span<const int64_t> mask_index_dims = mask_index != nullptr
-                                                 ? mask_index->Shape().GetDims()
-                                                 : gsl::span<const int64_t>{};
+  gsl::span<const int64_t> mask_index_dims =
+      mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span<const int64_t>{};
  const T* past_data = past != nullptr ? past->Data<T>() : nullptr;
  T* present_data = present != nullptr ? present->MutableData<T>() : nullptr;
  const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
@@ -84,22 +83,19 @@ class AttentionCPUBase : public AttentionBase {
    relative_position_bias_data = relative_position_bias->Data<T>();
}
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
-                             mask_index_data, mask_index_dims, static_cast<T*>(mask_data), causal,
-                             batch_size, sequence_length, kv_sequence_length, past_sequence_length,
-                             qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
-                             present_data, present_key_data, tp, relative_position_bias_data);
+    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K, mask_index_data, mask_index_dims,
+                             static_cast<T*>(mask_data), causal, batch_size, sequence_length, kv_sequence_length,
+                             past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, past_data,
+                             past_key_data, present_data, present_key_data, tp, relative_position_bias_data);
// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
      allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data),
-                            static_cast<T*>(attention_probs), V,
-                            batch_size, sequence_length, kv_sequence_length, past_sequence_length,
-                            v_head_size, v_hidden_size, past_data, past_value_data,
-                            present_data, present_value_data, tp);
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
+                            V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size,
+                            v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp);
return Status::OK();
}
@@ -138,16 +134,17 @@ class AttentionCPUBase : public AttentionBase {
{
// mask_data is nullptr when mask_index is nullptr and not unidirectional, otherwise its shape is BxSxT
if (mask_data != nullptr) {
- PrepareMask(mask_index, mask_index_dims, mask_data,
- causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
+ PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length,
+ mask_filter_value_);
}
const int loop_len = batch_size * num_heads_;
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_;
TensorOpCost unit_cost;
-    const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
+    const ptrdiff_t probs_matrix_bytes = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * sizeof(T);
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * total_sequence_length);
    unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
    unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
@@ -172,15 +169,13 @@ class AttentionCPUBase : public AttentionBase {
for (std::ptrdiff_t i = begin; i != end; ++i) {
        const int batch_index = static_cast<int>(i) / num_heads_;
-        const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
-        const int mask_offset = batch_index * sequence_length * total_sequence_length;
+        const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * total_sequence_length;
+        const ptrdiff_t mask_offset = SafeInt<ptrdiff_t>(batch_index) * sequence_length * total_sequence_length;
T* output = attention_probs + output_offset;
// Broadcast mask data: (Bx)SxT -> (BxNx)SxT
if (mask_data != nullptr) {
- memcpy(output,
- mask_data + mask_offset,
- probs_matrix_bytes);
+ memcpy(output, mask_data + mask_offset, probs_matrix_bytes);
}
const T* k = K + kv_input_chunk_length * i;
@@ -197,8 +192,8 @@ class AttentionCPUBase : public AttentionBase {
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
        math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
- Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
- output, nullptr);
+ Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output,
+ nullptr);
if (relative_position_bias_data != nullptr) {
for (int j = 0; j < sequence_length * total_sequence_length; j++) {
@@ -249,8 +244,10 @@ class AttentionCPUBase : public AttentionBase {
// The cost of Gemm
TensorOpCost unit_cost;
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
-    unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * v_head_size * total_sequence_length);
+    unit_cost.bytes_loaded =
+        static_cast<double>(SafeInt<ptrdiff_t>(sequence_length + v_head_size) * total_sequence_length * sizeof(T));
    unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));
if (present || present_value) {
@@ -264,35 +261,36 @@ class AttentionCPUBase : public AttentionBase {
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;
-  ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
- for (std::ptrdiff_t i = begin; i != end; ++i) {
- const T* v = V + kv_input_chunk_length * i;
- if (nullptr != present) {
- // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
- v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
- } else if (nullptr != present_value) {
- v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
- }
+ ThreadPool::TryParallelFor(
+      tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const T* v = V + kv_input_chunk_length * i;
+ if (nullptr != present) {
+ // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
+ v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
+ } else if (nullptr != present_value) {
+ v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
+ }
-      T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
-      ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
-      math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
- attention_probs + attention_probs_offset,
- v, current_tmp_data, nullptr);
-
- // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-      const int batch_index = static_cast<int>(i / num_heads_);
-      const int head_index = static_cast<int>(i % num_heads_);
-      T* src = current_tmp_data;
-      ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
- T* dest = output + dest_offset;
- for (int j = 0; j < sequence_length; j++) {
- memcpy(dest, src, bytes_to_copy_trans);
- src += v_head_size;
- dest += v_hidden_size;
- }
- }
- });
+          T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
+          ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
+          math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
+ attention_probs + attention_probs_offset, v, current_tmp_data, nullptr);
+
+ // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
+          const int batch_index = static_cast<int>(i / num_heads_);
+          const int head_index = static_cast<int>(i % num_heads_);
+          T* src = current_tmp_data;
+          ptrdiff_t dest_offset =
+              (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
+ T* dest = output + dest_offset;
+ for (int j = 0; j < sequence_length; j++) {
+ memcpy(dest, src, bytes_to_copy_trans);
+ src += v_head_size;
+ dest += v_hidden_size;
+ }
+ }
+ });
}
};
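The `SafeInt<ptrdiff_t>` changes in this file are overflow fixes, not just reformatting: with `int` offsets the product `i * sequence_length * total_sequence_length` is evaluated entirely in 32-bit arithmetic and can overflow for long sequences. A standalone illustration (the sizes below are hypothetical, not taken from this patch):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const int i = 31;                         // last (batch, head) chunk index
  const int sequence_length = 32768;        // hypothetical long-context S
  const int total_sequence_length = 32768;  // hypothetical T

  // In int arithmetic this product would be 33,285,996,544, far above INT_MAX
  // (2,147,483,647) -- signed overflow, i.e. undefined behavior. Widening the
  // first operand, as SafeInt<ptrdiff_t>(i) does in the diff, promotes the
  // whole chain to 64-bit (and SafeInt additionally traps on real overflow).
  const std::ptrdiff_t offset =
      static_cast<std::ptrdiff_t>(i) * sequence_length * total_sequence_length;

  std::printf("offset = %td elements\n", offset);  // prints 33285996544
  return 0;
}
```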
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index fa80efffc9ea1..6b0c5f395cab0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -63,17 +63,16 @@ class GQAAttentionBase : public AttentionBase {
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k,
-                             seqlens_k->Data<int32_t>(),
-                             batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
-                             head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);
+    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+                             sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
+                             present_key_data, past_present_share_buffer, packed_qkv, tp);
// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs),
-                            v, seqlens_k->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
-                            seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
-                            past_present_share_buffer, packed_qkv, tp);
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+                            batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                            hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
+                            tp);
return Status::OK();
}
@@ -98,7 +97,9 @@ class GQAAttentionBase : public AttentionBase {
bool packed_qkv, // whether Q, K, V are packed
ThreadPool* tp) const { // thread pool
const bool is_prompt = sequence_length != 1;
-    const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
+    const ptrdiff_t packed_batch_stride =
+        packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
+                   : SafeInt<ptrdiff_t>(0);
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
    const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;   // S x H
    const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;  // L x H
@@ -113,9 +114,12 @@ class GQAAttentionBase : public AttentionBase {
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_;
TensorOpCost unit_cost;
-    const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * present_buffer_sequence_length * sizeof(T);
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
-    unit_cost.bytes_loaded = static_cast<double>((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
+    const ptrdiff_t probs_matrix_bytes =
+        SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * sizeof(T);
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * present_buffer_sequence_length);
+    unit_cost.bytes_loaded =
+        static_cast<double>((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
    unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
    unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
@@ -131,11 +135,12 @@ class GQAAttentionBase : public AttentionBase {
for (std::ptrdiff_t i = begin; i != end; ++i) {
      const int batch_index = static_cast<int>(i) / num_heads_;
      const int head_index = static_cast<int>(i) % num_heads_;
-      const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
+      const int past_seqlen =
+          sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
      const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
      const int total_seqlen = seqlens_k[batch_index] + 1;
-      const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
+      const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
T* output = attention_probs + output_offset;
const T* k;
@@ -161,11 +166,9 @@ class GQAAttentionBase : public AttentionBase {
} else {
q = Q + q_input_chunk_length * i;
}
-          math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans,
-                                      sequence_length, total_seqlen, head_size, alpha,
-                                      q, head_size, k, head_size,
-                                      0.0f /*beta*/,
-                                      output, present_buffer_sequence_length, nullptr);
+          math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+                                      head_size, k, head_size, 0.0f /*beta*/, output, present_buffer_sequence_length,
+                                      nullptr);
// compute Softmax
T* output_softmax = output;
@@ -175,7 +178,8 @@ class GQAAttentionBase : public AttentionBase {
for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
}
- ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, local_window_size_ + 1, nullptr);
+ ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
+ local_window_size_ + 1, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
}
@@ -208,7 +212,9 @@ class GQAAttentionBase : public AttentionBase {
bool packed_qkv, // whether Q, K, V are packed
ThreadPool* tp) const {
const bool is_prompt = sequence_length != 1;
-    const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
+    const ptrdiff_t packed_batch_stride =
+        packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
+                   : SafeInt<ptrdiff_t>(0);
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
const int kv_input_chunk_length = sequence_length * head_size; // L x H
    const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;  // L x H
@@ -220,8 +226,10 @@ class GQAAttentionBase : public AttentionBase {
// The cost of Gemm
TensorOpCost unit_cost;
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
-    unit_cost.bytes_loaded = static_cast<double>((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T));
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * present_buffer_sequence_length);
+    unit_cost.bytes_loaded = static_cast<double>(SafeInt<ptrdiff_t>(sequence_length + head_size) *
+                                                 present_buffer_sequence_length * sizeof(T));
    unit_cost.bytes_stored = static_cast<double>(sequence_length * head_size * sizeof(T));
if (present_value) {
@@ -235,39 +243,37 @@ class GQAAttentionBase : public AttentionBase {
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;
-  ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
- for (std::ptrdiff_t i = begin; i != end; ++i) {
-      const int batch_index = static_cast<int>(i / num_heads_);
-      const int head_index = static_cast<int>(i % num_heads_);
-      const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
-      const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
- const int total_seqlen = seqlens_k[batch_index] + 1;
+ ThreadPool::TryParallelFor(
+      tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+          const int batch_index = static_cast<int>(i / num_heads_);
+          const int head_index = static_cast<int>(i % num_heads_);
+          const int past_seqlen =
+              sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
+          const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
+ const int total_seqlen = seqlens_k[batch_index] + 1;
+
+ const T* v;
+ if (packed_qkv) {
+ v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
+ } else {
+ v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
+ }
+ if (nullptr != present_value) {
+ v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
+ past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+ i / kv_num_heads_factor);
+ }
- const T* v;
- if (packed_qkv) {
- v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
- } else {
- v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
- }
- if (nullptr != present_value) {
- v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
- past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
- i / kv_num_heads_factor);
- }
+ T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
- T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
-      ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
-
-      math::GemmEx<T, ThreadPool>(CblasNoTrans,
-                                  CblasNoTrans,
-                                  sequence_length, head_size, total_seqlen,
-                                  1.f, /*alpha*/
-                                  attention_probs + attention_probs_offset, present_buffer_sequence_length,
-                                  v, head_size,
-                                  0.0f /*beta*/,
-                                  output_current, hidden_size, nullptr);
-    }
-  });
+          math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                      1.f, /*alpha*/
+                                      attention_probs + attention_probs_offset, present_buffer_sequence_length, v,
+                                      head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr);
+ }
+ });
}
};
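The same widening is applied to `packed_batch_stride` here, since with packed QKV a single batch spans `(num_heads + 2 * kv_num_heads) * S * H` elements. A sketch of how such a packed buffer is addressed (hypothetical helper, not the ORT API; the layout is assumed to be `[Q heads | K heads | V heads] x S x H` per batch, matching the pointer arithmetic above):

```cpp
#include <cstddef>

// Locate the K chunk for one (batch, kv-head) pair inside a packed QKV buffer.
const float* PackedKChunk(const float* packed_qkv, std::ptrdiff_t batch_index,
                          std::ptrdiff_t kv_head_index, std::ptrdiff_t num_heads,
                          std::ptrdiff_t kv_num_heads, std::ptrdiff_t sequence_length,
                          std::ptrdiff_t head_size) {
  // One batch holds (num_heads + 2 * kv_num_heads) * S * H elements; keeping
  // this stride in ptrdiff_t is exactly what the packed_batch_stride fix does.
  const std::ptrdiff_t packed_batch_stride =
      (num_heads + 2 * kv_num_heads) * sequence_length * head_size;
  // K heads start right after the num_heads query heads of the batch.
  return packed_qkv + packed_batch_stride * batch_index +
         (num_heads + kv_head_index) * sequence_length * head_size;
}
```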
diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
index 9d8f79c67d8a4..7bc3414c89978 100644
--- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
@@ -16,6 +16,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
@@ -38,6 +39,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>,
diff --git a/onnxruntime/contrib_ops/js/quick_gelu.cc b/onnxruntime/contrib_ops/js/quick_gelu.cc
new file mode 100644
index 0000000000000..4bb4d5afd4109
--- /dev/null
+++ b/onnxruntime/contrib_ops/js/quick_gelu.cc
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "quick_gelu.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+using onnxruntime::js::JsepSupportedFloatTypes;
+
+ONNX_OPERATOR_KERNEL_EX(
+ QuickGelu,
+ kMSDomain,
+ 1,
+ kJsExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", JsepSupportedFloatTypes()),
+ QuickGelu);
+
+} // namespace js
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/js/quick_gelu.h b/onnxruntime/contrib_ops/js/quick_gelu.h
new file mode 100644
index 0000000000000..51e39e2718d51
--- /dev/null
+++ b/onnxruntime/contrib_ops/js/quick_gelu.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/js/js_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+using onnxruntime::js::JsKernel;
+
+class QuickGelu final : public JsKernel {
+ public:
+ explicit QuickGelu(const OpKernelInfo& info) : JsKernel(info) {
+    float alpha = info.GetAttrOrDefault<float>("alpha", 1.0);
+ JSEP_INIT_KERNEL_ATTRIBUTE(QuickGelu, ({"alpha" : $1}), alpha);
+ }
+};
+
+} // namespace js
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc
index eac9a7fa08081..ad6f666a2d989 100644
--- a/onnxruntime/core/common/logging/logging.cc
+++ b/onnxruntime/core/common/logging/logging.cc
@@ -9,11 +9,11 @@
#include "core/common/exceptions.h"
#include "core/common/logging/isink.h"
#include "core/common/logging/logging.h"
+#include "core/common/logging/sinks/composite_sink.h"
#ifdef _WIN32
#include <Windows.h>
#include "core/platform/windows/logging/etw_sink.h"
-#include "core/common/logging/sinks/composite_sink.h"
#else
#include <unistd.h>
#if defined(__MACH__) || defined(__wasm__) || defined(_AIX)
@@ -22,10 +22,10 @@
#include <sys/syscall.h>
#endif
#endif
-#include "core/platform/ort_mutex.h"
#if __FreeBSD__
#include <sys/thr.h>  // Use thr_self() syscall under FreeBSD to get thread id
+#include "logging.h"
#endif
namespace onnxruntime {
@@ -52,6 +52,10 @@ static std::atomic& DefaultLoggerManagerInstance() noexcept {
return default_instance;
}
+LoggingManager* LoggingManager::GetDefaultInstance() {
+  return static_cast<LoggingManager*>(DefaultLoggerManagerInstance().load());
+}
+
// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial
// and should not have any destruction order issues via pragmas instead.
// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html
@@ -66,6 +70,7 @@ static OrtMutex& DefaultLoggerMutex() noexcept {
}
Logger* LoggingManager::s_default_logger_ = nullptr;
+OrtMutex sink_mutex_;
#ifdef _MSC_VER
#pragma warning(pop)
@@ -245,27 +250,27 @@ unsigned int GetProcessId() {
#endif
}
-std::unique_ptr<ISink> EnhanceLoggerWithEtw(std::unique_ptr<ISink> existingLogger, logging::Severity originalSeverity,
-                                            logging::Severity etwSeverity) {
+std::unique_ptr<ISink> EnhanceSinkWithEtw(std::unique_ptr<ISink> existing_sink, logging::Severity original_severity,
+                                          logging::Severity etw_severity) {
#ifdef _WIN32
auto& manager = EtwRegistrationManager::Instance();
if (manager.IsEnabled()) {
    auto compositeSink = std::make_unique<CompositeSink>();
-    compositeSink->AddSink(std::move(existingLogger), originalSeverity);
-    compositeSink->AddSink(std::make_unique<EtwSink>(), etwSeverity);
+    compositeSink->AddSink(std::move(existing_sink), original_severity);
+    compositeSink->AddSink(std::make_unique<EtwSink>(), etw_severity);
return compositeSink;
} else {
- return existingLogger;
+ return existing_sink;
}
#else
// On non-Windows platforms, just return the existing logger
- (void)originalSeverity;
- (void)etwSeverity;
- return existingLogger;
+ (void)original_severity;
+ (void)etw_severity;
+ return existing_sink;
#endif // _WIN32
}
-Severity OverrideLevelWithEtw(Severity originalSeverity) {
+Severity OverrideLevelWithEtw(Severity original_severity) {
#ifdef _WIN32
auto& manager = logging::EtwRegistrationManager::Instance();
if (manager.IsEnabled() &&
@@ -273,7 +278,50 @@ Severity OverrideLevelWithEtw(Severity originalSeverity) {
return manager.MapLevelToSeverity();
}
#endif // _WIN32
- return originalSeverity;
+ return original_severity;
+}
+
+bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function<std::unique_ptr<ISink>()> sinkFactory,
+ logging::Severity severity) {
+  std::lock_guard<OrtMutex> guard(sink_mutex_);
+ if (sink_->GetType() != SinkType::CompositeSink) {
+ // Current sink is not a composite, create a new composite sink and add the current sink to it
+    auto new_composite = std::make_unique<CompositeSink>();
+ new_composite->AddSink(std::move(sink_), default_min_severity_); // Move the current sink into the new composite
+ sink_ = std::move(new_composite); // Now sink_ is pointing to the new composite
+ }
+ // Adjust the default minimum severity level to accommodate new sink needs
+ default_min_severity_ = std::min(default_min_severity_, severity);
+ if (s_default_logger_ != nullptr) {
+ s_default_logger_->SetSeverity(default_min_severity_);
+ }
+  CompositeSink* current_composite = static_cast<CompositeSink*>(sink_.get());
+ if (current_composite->HasType(sink_type)) {
+ return false; // Sink of this type already exists, do not add another
+ }
+
+ current_composite->AddSink(sinkFactory(), severity);
+ return true;
+}
+
+void LoggingManager::RemoveSink(SinkType sink_type) {
+  std::lock_guard<OrtMutex> guard(sink_mutex_);
+
+ if (sink_->GetType() == SinkType::CompositeSink) {
+    auto composite_sink = static_cast<CompositeSink*>(sink_.get());
+
+ Severity newSeverity = composite_sink->RemoveSink(sink_type);
+
+ if (composite_sink->HasOnlyOneSink()) {
+ // If only one sink remains, replace the CompositeSink with this single sink
+ sink_ = composite_sink->GetRemoveSingleSink();
+ }
+
+ default_min_severity_ = newSeverity;
+ if (s_default_logger_ != nullptr) {
+ s_default_logger_->SetSeverity(default_min_severity_);
+ }
+ }
}
} // namespace logging
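A plausible call pattern for the new `AddSinkOfType`/`RemoveSink` API (usage sketch only; the logging types come from this diff, while the two wrapper functions are invented for illustration):

```cpp
#include <memory>

#include "core/common/logging/logging.h"             // LoggingManager, Severity
#include "core/platform/windows/logging/etw_sink.h"  // EtwSink (Windows only)

using onnxruntime::logging::EtwSink;
using onnxruntime::logging::ISink;
using onnxruntime::logging::LoggingManager;
using onnxruntime::logging::Severity;
using onnxruntime::logging::SinkType;

// Attach an ETW sink while a trace session is active. The factory defers
// construction: AddSinkOfType only invokes it when no EtwSink is registered
// yet, and returns false otherwise.
void OnEtwProviderEnabled(Severity etw_severity) {
  if (LoggingManager* manager = LoggingManager::GetDefaultInstance()) {
    manager->AddSinkOfType(
        SinkType::EtwSink,
        []() -> std::unique_ptr<ISink> { return std::make_unique<EtwSink>(); },
        etw_severity);
  }
}

// Detach on provider disable; RemoveSink collapses the CompositeSink back to
// the single remaining sink when only one is left.
void OnEtwProviderDisabled() {
  if (LoggingManager* manager = LoggingManager::GetDefaultInstance()) {
    manager->RemoveSink(SinkType::EtwSink);
  }
}
```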
diff --git a/onnxruntime/core/common/logging/sinks/composite_sink.h b/onnxruntime/core/common/logging/sinks/composite_sink.h
index 9d18eb527ffdd..e4a85f7d556bc 100644
--- a/onnxruntime/core/common/logging/sinks/composite_sink.h
+++ b/onnxruntime/core/common/logging/sinks/composite_sink.h
@@ -23,7 +23,17 @@ class CompositeSink : public ISink {
  /// Initializes a new instance of the <see cref="CompositeSink"/> class.
  /// Use AddSink to add sinks.
  /// </summary>
- CompositeSink() {}
+ CompositeSink() : ISink(SinkType::CompositeSink) {}
+
+  /// <summary>
+  /// Check if the composite sink contains a sink of the specified type.
+  /// </summary>
+ bool HasType(SinkType sink_type) const {
+ return std::any_of(sinks_with_severity_.begin(), sinks_with_severity_.end(),
+ [&](const auto& sink_pair) {
+ return sink_pair.first->GetType() == sink_type;
+ });
+ }
  /// <summary>
/// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value).
@@ -37,11 +47,48 @@ class CompositeSink : public ISink {
}
  /// <summary>
-  /// Gets a const reference to the collection of sinks and min severity for that sink
+  /// Remove a sink of the specified type.
+  /// </summary>
+  /// <param name="sink_type">Sink type to remove</param>
+  /// <returns>Minimum severity of the remaining sinks</returns>
+ logging::Severity RemoveSink(SinkType sink_type) {
+ logging::Severity severity = Severity::kFATAL; // default if we end up with no sinks
+
+ // find entries to remove and the minimum severity of the remaining sinks
+ auto entries_to_remove = std::remove_if(sinks_with_severity_.begin(), sinks_with_severity_.end(),
+ [&](const auto& entry) {
+ if (entry.first->GetType() == sink_type) {
+ return true;
+ } else {
+ severity = std::min(severity, entry.second);
+ return false;
+ }
+ });
+
+ sinks_with_severity_.erase(entries_to_remove, sinks_with_severity_.end());
+
+ return severity;
+ }
+
+  /// <summary>
+  /// Check if there's only one sink left
+  /// </summary>
+  /// <returns>True if only 1 sink remaining</returns>
+ bool HasOnlyOneSink() const {
+ return sinks_with_severity_.size() == 1;
+ }
+
+  /// <summary>
+  /// If one sink is remaining then returns it and empties the composite sink
+  /// </summary>
-  /// <returns>A const reference to the vector pair of unique_ptr to ISink and severity.</returns>
-  const std::vector<std::pair<std::unique_ptr<ISink>, logging::Severity>>& GetSinks() const {
-    return sinks_with_severity_;
+  /// <returns>If one sink remains then returns the sink, otherwise nullptr</returns>
+  std::unique_ptr<ISink> GetRemoveSingleSink() {
+ if (HasOnlyOneSink()) {
+ auto single_sink = std::move(sinks_with_severity_.begin()->first);
+ sinks_with_severity_.clear();
+ return single_sink;
+ }
+ return nullptr;
}
private:
diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
index ba6b417050e2d..2d4d791c3a000 100644
--- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
+++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
@@ -2,6 +2,9 @@
#include "mlasi.h"
#include
+// NOTE: Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions.
+// ONNX Runtime CI pipelines do not build with all compiler versions.
+
template <typename OutputType>
void
MLASCALL
@@ -194,7 +197,7 @@ Return Value:
auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3);
auto CharVector = vec_pack(ShortVector0, ShortVector1);
-        vec_xst(CharVector, 0, reinterpret_cast<int8_t *>(&TmpOutput[0]));
+ vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0]));
MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]);
MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]);
diff --git a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp
index 633349e800875..a67be1dbfa710 100644
--- a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp
+++ b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp
@@ -67,7 +67,7 @@ MlasGemmQuantFixupZeroPointB(
}
-template <typename Vtype>
+template <typename Vtype, bool AIsSigned>
void
MlasGemmQuantCopyPackA8x8(
MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType* D,
@@ -75,11 +75,10 @@ MlasGemmQuantCopyPackA8x8(
size_t lda,
size_t CountM,
size_t CountK,
- int32_t* RowSumBuffer,
- bool AIsSigned
+ int32_t* RowSumBuffer
)
{
- const uint8_t Flip = (AIsSigned ? 0 : 0x80);
+ constexpr uint8_t Flip = (AIsSigned ? 0 : 0x80);
    Vtype vmask = reinterpret_cast<Vtype>(vec_splats(Flip));
typedef __vector signed char vec_t;
@@ -106,66 +105,74 @@ MlasGemmQuantCopyPackA8x8(
        Vtype a3 = *reinterpret_cast<const Vtype *>(&a[lda * 2]);
        Vtype a4 = *reinterpret_cast<const Vtype *>(&a[lda * 3]);
        Vtype vx =
-            reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a1),
+            reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1),
                             reinterpret_cast<__vector int>(a2)));
        Vtype vx1 =
-            reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a3),
+            reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3),
                             reinterpret_cast<__vector int>(a4)));
        Vtype vx2 =
-            reinterpret_cast<Vtype>(vec_mergeo (reinterpret_cast<__vector int>(a1),
+            reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a1),
                             reinterpret_cast<__vector int>(a2)));
        Vtype vx3 =
-            reinterpret_cast<Vtype>(vec_mergeo (reinterpret_cast<__vector int>(a3),
+            reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a3),
                             reinterpret_cast<__vector int>(a4)));
- Vtype vx4 = vec_xxpermdi (vx, vx1, 0);
- Vtype vx5 = vec_xxpermdi (vx2, vx3, 0);
- Vtype vx6 = vec_xxpermdi (vx, vx1, 3);
- Vtype vx7 = vec_xxpermdi (vx2, vx3, 3);
+ Vtype vx4 = vec_xxpermdi(vx, vx1, 0);
+ Vtype vx5 = vec_xxpermdi(vx2, vx3, 0);
+ Vtype vx6 = vec_xxpermdi(vx, vx1, 3);
+ Vtype vx7 = vec_xxpermdi(vx2, vx3, 3);
        a1 = *reinterpret_cast<const Vtype *>(&a[lda*4]);
        a2 = *reinterpret_cast<const Vtype *>(&a[lda*5]);
        a3 = *reinterpret_cast<const Vtype *>(&a[lda*6]);
        a4 = *reinterpret_cast<const Vtype *>(&a[lda*7]);
        vx =
-            reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a1),
+            reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1),
                             reinterpret_cast<__vector int>(a2)));
        vx1 =
-            reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a3),
+            reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3),
                             reinterpret_cast<__vector int>(a4)));
        vx2 =
-            reinterpret_cast<Vtype>(vec_mergeo (reinterpret_cast<__vector int>(a1),
+            reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a1),
                             reinterpret_cast<__vector int>(a2)));
        vx3 =
-            reinterpret_cast<Vtype>(vec_mergeo (reinterpret_cast<__vector int>(a3),
+            reinterpret_cast<Vtype>(vec_mergeo(reinterpret_cast<__vector int>(a3),
                             reinterpret_cast<__vector int>(a4)));
- Vtype vx8 = vec_xxpermdi (vx, vx1, 0);
- Vtype vx9 = vec_xxpermdi (vx2, vx3, 0);
- Vtype vx10 = vec_xxpermdi (vx, vx1, 3);
- Vtype vx11 = vec_xxpermdi (vx2, vx3, 3);
+ Vtype vx8 = vec_xxpermdi(vx, vx1, 0);
+ Vtype vx9 = vec_xxpermdi(vx2, vx3, 0);
+ Vtype vx10 = vec_xxpermdi(vx, vx1, 3);
+ Vtype vx11 = vec_xxpermdi(vx2, vx3, 3);
        vec_t vxx =
-            reinterpret_cast<vec_t>(vec_sub (vx4, vmask));
-        vsum = vec_sum4s (vxx, vsum);
+            AIsSigned ? reinterpret_cast<vec_t>(vx4) :
+            reinterpret_cast<vec_t>(vec_sub(vx4, vmask));
+        vsum = vec_sum4s(vxx, vsum);
        *reinterpret_cast<vec_t *>(&D[0]) = vxx;
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx5, vmask));
-        vsum = vec_sum4s (vxx, vsum);
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx5) :
+              reinterpret_cast<vec_t>(vec_sub(vx5, vmask));
+        vsum = vec_sum4s(vxx, vsum);
        *reinterpret_cast<vec_t *>(&D[16]) = vxx;
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx6, vmask));
-        vsum = vec_sum4s (vxx, vsum);
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx6) :
+              reinterpret_cast<vec_t>(vec_sub(vx6, vmask));
+        vsum = vec_sum4s(vxx, vsum);
        *reinterpret_cast<vec_t *>(&D[32]) = vxx;
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx7, vmask));
-        vsum = vec_sum4s (vxx, vsum);
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx7) :
+              reinterpret_cast<vec_t>(vec_sub(vx7, vmask));
+        vsum = vec_sum4s(vxx, vsum);
        *reinterpret_cast<vec_t *>(&D[48]) = vxx;
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx8, vmask));
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx8) :
+              reinterpret_cast<vec_t>(vec_sub(vx8, vmask));
        *reinterpret_cast<vec_t *>(&D[64]) = vxx;
-        vsum2 = vec_sum4s (vxx, vsum2);
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx9, vmask));
+        vsum2 = vec_sum4s(vxx, vsum2);
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx9) :
+              reinterpret_cast<vec_t>(vec_sub(vx9, vmask));
        *reinterpret_cast<vec_t *>(&D[80]) = vxx;
-        vsum2 = vec_sum4s (vxx, vsum2);
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx10, vmask));
+        vsum2 = vec_sum4s(vxx, vsum2);
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx10) :
+              reinterpret_cast<vec_t>(vec_sub(vx10, vmask));
        *reinterpret_cast<vec_t *>(&D[96]) = vxx;
-        vsum2 = vec_sum4s (vxx, vsum2);
-        vxx = reinterpret_cast<vec_t>(vec_sub (vx11, vmask));
+        vsum2 = vec_sum4s(vxx, vsum2);
+        vxx = AIsSigned ? reinterpret_cast<vec_t>(vx11) :
+              reinterpret_cast<vec_t>(vec_sub(vx11, vmask));
        *reinterpret_cast<vec_t *>(&D[112]) = vxx;
-        vsum2 = vec_sum4s (vxx, vsum2);
+        vsum2 = vec_sum4s(vxx, vsum2);
D += 16 * 8;
a += 16;
y -= 16;
@@ -179,16 +186,18 @@ MlasGemmQuantCopyPackA8x8(
        int a4 = *reinterpret_cast<const int *>(&a[lda*3]);
        __vector int vx1 = { a1, a2, a3, a4};
        vec_t vx =
-            reinterpret_cast<vec_t>(vec_sub (reinterpret_cast<Vtype>(vx1), vmask));
-        vsum = vec_sum4s (vx, vsum);
+            AIsSigned ? reinterpret_cast<vec_t>(vx1) :
+            reinterpret_cast<vec_t>(vec_sub(reinterpret_cast<Vtype>(vx1), vmask));
+        vsum = vec_sum4s(vx, vsum);
        *reinterpret_cast<vec_t *>(&D[0]) = vx;
        a1 = *reinterpret_cast<const int *>(&a[lda*4]);
        a2 = *reinterpret_cast<const int *>(&a[lda*5]);
        a3 = *reinterpret_cast<const int *>(&a[lda*6]);
        a4 = *reinterpret_cast<const int *>(&a[lda*7]);
__vector int vx2 = { a1, a2, a3, a4};
-        vx = reinterpret_cast<vec_t>(vec_sub (reinterpret_cast<Vtype>(vx2), vmask));
-        vsum2 = vec_sum4s (vx, vsum2);
+        vx = AIsSigned ? reinterpret_cast<vec_t>(vx2) :
+             reinterpret_cast<vec_t>(vec_sub(reinterpret_cast<Vtype>(vx2), vmask));
+        vsum2 = vec_sum4s(vx, vsum2);
        if (CountK & 3) {
            if (yval >= 12) {
                *reinterpret_cast<vec_t *>(&D[64]) = vx;
@@ -225,10 +234,10 @@ MlasGemmQuantCopyPackA8x8(
}
if (y >= 1)
{
-            Vtype a1 = reinterpret_cast<Vtype>(vec_splats(Flip));
-            Vtype a2 = reinterpret_cast<Vtype>(vec_splats(Flip));
-            Vtype a3 = reinterpret_cast<Vtype>(vec_splats(Flip));
-            Vtype a4 = reinterpret_cast<Vtype>(vec_splats(Flip));
+ Vtype a1 = vmask;
+ Vtype a2 = vmask;
+ Vtype a3 = vmask;
+ Vtype a4 = vmask;
a1[0] = a[0];
a2[0] = a[lda];
a3[0] = a[lda * 2];
@@ -246,20 +255,21 @@ MlasGemmQuantCopyPackA8x8(
a4[2] = a[lda * 3 + 2];
}
            Vtype vx =
-                reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a1),
+                reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1),
                                 reinterpret_cast<__vector int>(a2)));
            Vtype vx1 =
-                reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a3),
+                reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3),
                                 reinterpret_cast<__vector int>(a4)));
-            Vtype vx2 = vec_xxpermdi (vx, vx1, 0);
+            Vtype vx2 = vec_xxpermdi(vx, vx1, 0);
            vec_t vx3 =
-                reinterpret_cast<vec_t>(vec_sub (vx2, vmask));
-            vsum = vec_sum4s (vx3, vsum);
+                AIsSigned ? reinterpret_cast<vec_t>(vx2) :
+                reinterpret_cast<vec_t>(vec_sub(vx2, vmask));
+            vsum = vec_sum4s(vx3, vsum);
            *reinterpret_cast<vec_t *>(&D[0]) = vx3;
-            a1 = reinterpret_cast<Vtype>(vec_splats(Flip));
-            a2 = reinterpret_cast<Vtype>(vec_splats(Flip));
-            a3 = reinterpret_cast<Vtype>(vec_splats(Flip));
-            a4 = reinterpret_cast<Vtype>(vec_splats(Flip));
+ a1 = vmask;
+ a2 = vmask;
+ a3 = vmask;
+ a4 = vmask;
a1[0] = a[lda * 4];
a2[0] = a[lda * 5];
a3[0] = a[lda * 6];
@@ -277,14 +287,15 @@ MlasGemmQuantCopyPackA8x8(
a4[2] = a[lda * 7 + 2];
}
vx =
-                reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a1),
+                reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a1),
                                 reinterpret_cast<__vector int>(a2)));
            vx1 =
-                reinterpret_cast<Vtype>(vec_mergee (reinterpret_cast<__vector int>(a3),
+                reinterpret_cast<Vtype>(vec_mergee(reinterpret_cast<__vector int>(a3),
                                 reinterpret_cast<__vector int>(a4)));
-            vx2 = vec_xxpermdi (vx, vx1, 0);
-            vx3 = reinterpret_cast<vec_t>(vec_sub (vx2, vmask));
-            vsum2 = vec_sum4s (vx3, vsum2);
+            vx2 = vec_xxpermdi(vx, vx1, 0);
+            vx3 = AIsSigned ? reinterpret_cast<vec_t>(vx2) :
+                  reinterpret_cast<vec_t>(vec_sub(vx2, vmask));
+            vsum2 = vec_sum4s(vx3, vsum2);
            if (CountK % 16 >= 12) {
                *reinterpret_cast<vec_t *>(&D[64]) = vx3;
D += 80;
@@ -327,34 +338,38 @@ MlasGemmQuantCopyPackA8x8(
Vtype a3 = *reinterpret_cast