diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch
index 66ef0c5125a74..584c49d612293 100644
--- a/cmake/patches/abseil/absl_windows.patch
+++ b/cmake/patches/abseil/absl_windows.patch
@@ -25,17 +25,90 @@ index a6efc98e..8c4de8e7 100644
         "/wd4800",
     ]
 diff --git a/absl/copts/copts.py b/absl/copts/copts.py
-index 0d6c1ec3..75fd935f 100644
+index e6e11949..0aa7d868 100644
 --- a/absl/copts/copts.py
 +++ b/absl/copts/copts.py
-@@ -132,10 +132,6 @@ COPT_VARS = {
-     "/wd4068",  # unknown pragma
-     # qualifier applied to function type has no meaning; ignored
-     "/wd4180",
--    # conversion from 'type1' to 'type2', possible loss of data
--    "/wd4244",
--    # conversion from 'size_t' to 'type', possible loss of data
--    "/wd4267",
-     # The decorated name was longer than the compiler limit
-     "/wd4503",
-     # forcing value to bool 'true' or 'false' (performance warning)
+@@ -115,10 +115,6 @@ MSVC_WARNING_FLAGS = [
+     "/wd4068",  # unknown pragma
+     # qualifier applied to function type has no meaning; ignored
+     "/wd4180",
+-    # conversion from 'type1' to 'type2', possible loss of data
+-    "/wd4244",
+-    # conversion from 'size_t' to 'type', possible loss of data
+-    "/wd4267",
+     # The decorated name was longer than the compiler limit
+     "/wd4503",
+     # forcing value to bool 'true' or 'false' (performance warning)
+diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc
+index 53a099a1..34d210d6 100644
+--- a/absl/debugging/symbolize_win32.inc
++++ b/absl/debugging/symbolize_win32.inc
+@@ -35,15 +35,15 @@ ABSL_NAMESPACE_BEGIN
+ 
+ static HANDLE process = NULL;
+ 
+-void InitializeSymbolizer(const char*) {
+-  if (process != nullptr) {
+-    return;
+-  }
++namespace {
++void InitializeSymbolizerImpl() {
++
+   process = GetCurrentProcess();
+ 
+   // Symbols are not loaded until a reference is made requiring the
+   // symbols be loaded. This is the fastest, most efficient way to use
+   // the symbol handler.
++
+   SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME);
+   if (!SymInitialize(process, nullptr, true)) {
+     // GetLastError() returns a Win32 DWORD, but we assign to
+@@ -54,6 +54,35 @@ void InitializeSymbolizer(const char*) {
+   }
+ }
+ 
++bool LookupAndInitialize(const void* pc, SYMBOL_INFO* symbol) {
++  auto hProcess = (process != NULL) ? process : GetCurrentProcess();
++  if (SymFromAddr(hProcess, reinterpret_cast<DWORD64>(pc), nullptr, symbol) != TRUE) {
++    if (GetLastError() == ERROR_INVALID_HANDLE && process == NULL) {
++      InitializeSymbolizerImpl();
++      if (SymFromAddr(process, reinterpret_cast<DWORD64>(pc), nullptr, symbol) != TRUE) {
++        return false;
++      }
++    } else {
++      return false;
++    }
++  }
++  return true;
++}
++}
++
++void InitializeSymbolizer(const char*) {
++  if (process != nullptr) {
++    return;
++  }
++
++  alignas(SYMBOL_INFO) char buf[sizeof(SYMBOL_INFO) + MAX_SYM_NAME];
++  SYMBOL_INFO* symbol = reinterpret_cast<SYMBOL_INFO*>(buf);
++  symbol->SizeOfStruct = sizeof(SYMBOL_INFO);
++  symbol->MaxNameLen = MAX_SYM_NAME;
++
++  static_cast<void>(LookupAndInitialize(reinterpret_cast<const void*>(&InitializeSymbolizer), symbol));
++}
++
+ bool Symbolize(const void* pc, char* out, int out_size) {
+   if (out_size <= 0) {
+     return false;
+@@ -62,9 +91,11 @@ bool Symbolize(const void* pc, char* out, int out_size) {
+   SYMBOL_INFO* symbol = reinterpret_cast<SYMBOL_INFO*>(buf);
+   symbol->SizeOfStruct = sizeof(SYMBOL_INFO);
+   symbol->MaxNameLen = MAX_SYM_NAME;
+-  if (!SymFromAddr(process, reinterpret_cast<DWORD64>(pc), nullptr, symbol)) {
++
++  if (!LookupAndInitialize(pc, symbol)) {
+     return false;
+   }
++
+   const size_t out_size_t = static_cast<size_t>(out_size);
+   strncpy(out, symbol->Name, out_size_t);
+   if (out[out_size_t - 1] != '\0') {
diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp
index 955b7c5deee9a..43a12b37e4ffa 100644
--- a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp
+++ b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp
@@ -171,11 +171,9 @@ Return Value:
         if (k > 0) {
 
             Row0AElements0 = a[0];
-            Row0AElements1 = a[1];
 
             if (ProcessTwoRows) {
                 Row1AElements0 = a[lda];
-                Row1AElements1 = a[lda + 1];
             }
 
             BElements0 = MlasLoadFloat32x4(B + 0);
diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc
index ff6a059607367..f3520b4f7f7f5 100644
--- a/onnxruntime/core/platform/windows/debug_alloc.cc
+++ b/onnxruntime/core/platform/windows/debug_alloc.cc
@@ -55,41 +55,67 @@ struct MemoryBlock {
 };
 
 struct SymbolHelper {
-  SymbolHelper() noexcept {
-    SymSetOptions(SymGetOptions() | SYMOPT_DEFERRED_LOADS);
-    SymInitialize(GetCurrentProcess(), nullptr, true);
+  HANDLE process_handle_ = GetCurrentProcess();
+  bool initialized_ = false;
+
+  bool InitializeWhenNeeded() {
+    // We try only once
+    if (!initialized_) {
+      SymSetOptions(SymGetOptions() | SYMOPT_DEFERRED_LOADS);
+      // We use GetCurrentProcess() because other libs are likely to use it
+      if (!SymInitialize(process_handle_, nullptr, true)) {
+        const unsigned long long error{GetLastError()};
+        std::cerr << "SymInitialize() failed: " << error << std::endl;
+        return false;
+      }
+      initialized_ = true;
+    }
+    return true;
+  }
+
+  SymbolHelper() = default;
+
+  static constexpr size_t kInitialBufferSize = sizeof(SYMBOL_INFO) + MAX_SYM_NAME;
+
+  bool LookupSymAndInitialize(const ULONG_PTR address, char* buffer, size_t buffer_size, SYMBOL_INFO* symbol) {
+    if (SymFromAddr(process_handle_, address, 0, symbol) != TRUE) {
+      if (GetLastError() == ERROR_INVALID_HANDLE) {
+        // Try to initialize first
+        if (!InitializeWhenNeeded() || SymFromAddr(process_handle_, address, 0, symbol) != TRUE) {
+          _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address);
+          return false;
+        }
+      } else {
+        _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address);
+        return false;
+      }
+    }
+    return true;
   }
 
   void Lookup(std::string& string, const ULONG_PTR address) {
-    char buffer[2048] = {0};
-    Symbol symbol;
-    if (SymFromAddr(GetCurrentProcess(), address, 0, &symbol) == false) {
-      _snprintf_s(buffer, _TRUNCATE, "0x%08IX (Unknown symbol)", address);
+    alignas(SYMBOL_INFO) char buffer[kInitialBufferSize] = {0};
+    SYMBOL_INFO* symbol = reinterpret_cast<SYMBOL_INFO*>(buffer);
+    symbol->SizeOfStruct = sizeof(SYMBOL_INFO);
+    symbol->MaxNameLen = MAX_SYM_NAME;
+
+    if (!LookupSymAndInitialize(address, buffer, kInitialBufferSize, symbol)) {
       string.append(buffer);
       return;
     }
 
     Line line;
     DWORD displacement;
-    if (SymGetLineFromAddr(GetCurrentProcess(), address, &displacement, &line) == false) {
-      _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol.Name);
+    if (SymGetLineFromAddr(process_handle_, address, &displacement, &line) == false) {
+      _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol->Name);
       string.append(buffer);
       return;
     }
 
-    _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast<int>(line.LineNumber), symbol.Name);
+    _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast<int>(line.LineNumber), symbol->Name);
     string.append(buffer);
   }
 
-  struct Symbol : SYMBOL_INFO {
-    Symbol() noexcept {
-      SizeOfStruct = sizeof(SYMBOL_INFO);
-      MaxNameLen = _countof(buffer);
-    }
-
-    char buffer[1024] = {0};
-  };
-
   struct Line : IMAGEHLP_LINE {
     Line() noexcept {
       SizeOfStruct = sizeof(IMAGEHLP_LINE);
diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py
index 908f846ab910b..80617b7b5edaa 100644
--- a/onnxruntime/python/tools/quantization/base_quantizer.py
+++ b/onnxruntime/python/tools/quantization/base_quantizer.py
@@ -452,7 +452,7 @@ def quantize_weight_per_channel_impl(
 
         return q_weight_name, zp_name, scale_name
 
-    def adjust_tensor_ranges(self, softmax_0_to_1=False):
+    def adjust_tensor_ranges(self):
         if self.tensors_range is None:
             return
 
@@ -471,6 +471,6 @@ def adjust_tensor_ranges(self, softmax_0_to_1=False):
                 if not isinstance(td, TensorData):
                     raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
                 self.tensors_range[node.input[0]] = td
-            # Optionally, adjust Softmax to range from 0.0 to 1.0
-            elif node.op_type == "Softmax" and softmax_0_to_1:
+            # Adjust Softmax to range from 0.0 to 1.0
+            elif node.op_type == "Softmax":
                 self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 253ed5e62b679..4b76de6ecf1cb 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -955,7 +955,7 @@ def calculate_quantization_params(self):
         if self.tensors_range is None:
             return None
 
-        self.adjust_tensor_ranges(softmax_0_to_1=False)
+        self.adjust_tensor_ranges()
 
         quantization_params = {}
         for tensor_name in self.tensors_range:
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index bbd22f508da23..c323c6fec545a 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -1098,7 +1098,7 @@ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]:
         if self.tensors_range is None:
             return {}
 
-        self.adjust_tensor_ranges(softmax_0_to_1=True)  # Ensure Softmax ranges from 0.0 to 1.0 for QDQ models.
+        self.adjust_tensor_ranges()
 
         quantization_params = {}
         for tensor_name in self.tensors_range:
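Editorial note: taken together, the three quantizer changes above drop the `softmax_0_to_1` flag and make the Softmax range adjustment unconditional on both the QOperator and QDQ paths. A minimal sketch of the resulting behavior follows; the tensor name and calibration numbers are made up, and it assumes `TensorData` is importable from `onnxruntime.quantization.calibrate` as in the patched files.

```python
# Illustrative only: what adjust_tensor_ranges() now does for a Softmax output.
# "softmax_out" and the observed range are hypothetical; attribute access is
# assumed to mirror the constructor kwargs used in the diff above.
import numpy as np

from onnxruntime.quantization.calibrate import TensorData

# Range a calibrator might have observed for a Softmax output tensor:
tensors_range = {"softmax_out": TensorData(lowest=np.float32(0.01), highest=np.float32(0.95))}

# After this patch the range is always pinned to Softmax's full codomain
# [0.0, 1.0]; no flag is consulted anymore.
tensors_range["softmax_out"] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
print(tensors_range["softmax_out"].lowest, tensors_range["softmax_out"].highest)  # 0.0 1.0
```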
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 4d0d2e68e8983..47b7f35cbdd7c 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -400,11 +400,7 @@ def main():
             sampling_times.append(sampling_end_time - sampling_start_time)
 
             all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
-
-            # Return early if all batch entries have reached EOS token id
             current_length += 1
-            if torch.all(has_eos) or current_length > max_length:
-                break
 
             # Update inputs for next inference run
             inputs["input_ids"] = tokens_to_add
diff --git a/onnxruntime/python/tools/transformers/onnx_utils.py b/onnxruntime/python/tools/transformers/onnx_utils.py
new file mode 100644
index 0000000000000..64fade9369395
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/onnx_utils.py
@@ -0,0 +1,55 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from fusion_utils import NumpyHelper
+from onnx import ModelProto, TensorProto
+from onnx.external_data_helper import set_external_data
+from onnx_model import OnnxModel
+
+from onnxruntime import OrtValue
+
+
+def extract_raw_data_from_model(model: ModelProto):
+    """
+    Extract external data from model and return the external data as a list of tuples (name, value).
+    Note this function does not handle external data that is not loaded into the model as raw data.
+
+    Args:
+        model (ModelProto): the model proto to extract external data from.
+    Returns:
+        (external_names, external_values): a tuple of two lists of external data names and values.
+    """
+    external_data = []
+    onnx_model = OnnxModel(model)
+    for graph in onnx_model.graphs():
+        for initializer in graph.initializer:
+            name = initializer.name
+
+            if initializer.HasField("raw_data"):
+                numpy_tensor = NumpyHelper.to_array(initializer)
+                ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
+                external_data.append((name, ort_value))
+                # mimic set_external_data
+                set_external_data(initializer, location="foo.bin")
+                initializer.name = name
+                initializer.ClearField("raw_data")
+
+    return zip(*external_data)
+
+
+def has_external_data(model: ModelProto):
+    """
+    Check if the model has external data.
+
+    Args:
+        model (ModelProto): the model proto to check for external data.
+    Returns:
+        bool: True if the model has external data, False otherwise.
+    """
+    onnx_model = OnnxModel(model)
+    for graph in onnx_model.graphs():
+        for initializer in graph.initializer:
+            if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
+                return True
+    return False
" @@ -125,7 +131,10 @@ def optimize_by_onnxruntime( sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL if optimized_model_path is None: - path_prefix = onnx_model_path[:-5] # remove .onnx suffix + if isinstance(onnx_model, str): + path_prefix = str(Path(onnx_model).with_suffix("")) # remove .onnx suffix + else: + path_prefix = "optimized_model" optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu") sess_options.optimized_model_filepath = optimized_model_path @@ -174,7 +183,20 @@ def optimize_by_onnxruntime( else: providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) + # For large model, extract external data from model and add to session options + if isinstance(onnx_model, ModelProto): + if has_external_data(onnx_model): + raise ValueError( + "ModelProto has external data not loaded into memory, ORT cannot create session. " + "Please load external data before calling this function. " + "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information." + ) + external_names, external_values = extract_raw_data_from_model(onnx_model) + sess_options.add_external_initializers(list(external_names), list(external_values)) + + # Inference session is only used to optimize the model. + onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model + onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -187,7 +209,7 @@ def optimize_by_fusion( num_heads: int = 0, hidden_size: int = 0, optimization_options: Optional[FusionOptions] = None, -): +) -> OnnxModel: """Optimize Model by graph fusion logic. Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable @@ -241,7 +263,7 @@ def optimize_by_fusion( def optimize_model( - input: str, + input: Union[str, ModelProto], model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, @@ -252,7 +274,7 @@ def optimize_model( verbose: bool = False, *, provider: Optional[str] = None, -): +) -> OnnxModel: """Optimize Model by OnnxRuntime and/or python fusion logic. ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). @@ -275,7 +297,7 @@ def optimize_model( For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters. Args: - input (str): input model path. + input (str | ModelProto): input model path or ModelProto. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. 0 allows detect the parameter from graph automatically. 
@@ -298,9 +320,9 @@ def optimize_model( if model_type not in MODEL_TYPES: logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.") - return OnnxModel(load_model(input)) + return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input) - (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type] + (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type] if opt_level is None: opt_level = default_opt_level @@ -316,11 +338,9 @@ def optimize_model( # Auto detect if input model has external data has_external_data_file = False - original_model = load_model(input, load_external_data=False) - for initializer in original_model.graph.initializer: - if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: - has_external_data_file = True - break + original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input + if has_external_data(original_model): + has_external_data_file = True del original_model if opt_level > 1: @@ -365,7 +385,12 @@ def optimize_model( if only_onnxruntime and not temp_model_path: logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True") - model = load_model(temp_model_path or input) + if temp_model_path is not None: + model = load_model(temp_model_path) + elif isinstance(input, str): + model = load_model(input) + else: + model = input if only_onnxruntime: optimizer = optimizer_class(model, num_heads, hidden_size) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 4148e63d58619..f4d2f68d4d8b5 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -36,7 +36,7 @@ struct ATenOperator { size_t return_size; std::vector ret_kinds; - c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const { + c10::IValue ToIValueArgument(DLManagedTensor* dlpack, size_t index) const { TORCH_INTERNAL_ASSERT(index < argument_size); bool is_optional = is_optional_arguments[index]; TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] || @@ -57,7 +57,7 @@ struct ATenOperator { c10::IValue i_value; // Create the torch tensor from this DLPack no matter we need it or not below, // so that the dlpack's deleter will be triggered when torch tensor is out of scope. - at::Tensor tensor = at::fromDLPack(const_cast(dlpack)); + at::Tensor tensor = at::fromDLPack(dlpack); switch (elem_kinds[index]) { case c10::TypeKind::TensorType: { i_value = is_optional ? c10::IValue(c10::optional(tensor)) : c10::IValue(tensor); diff --git a/onnxruntime/test/python/transformers/test_onnx_utils.py b/onnxruntime/test/python/transformers/test_onnx_utils.py new file mode 100644 index 0000000000000..974991359795e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_utils.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
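Editorial note: with the plumbing above, `optimize_model` now accepts an in-memory `ModelProto` as well as a path. A usage sketch; the file names and BERT hyperparameters are placeholders:

```python
# Sketch: optimizing an in-memory ModelProto instead of a model path.
import onnx
from onnxruntime.transformers.optimizer import optimize_model

model_proto = onnx.load("bert.onnx")  # weights must be in raw_data, not left external
optimizer = optimize_model(model_proto, model_type="bert", num_heads=12, hidden_size=768)
optimizer.save_model_to_file("bert_opt.onnx")
```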
+# -------------------------------------------------------------------------- +import unittest + +import numpy +from onnx import ModelProto, TensorProto, helper +from onnx.external_data_helper import set_external_data + +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + + +class TestOnnxUtils(unittest.TestCase): + def test_extract_raw_data_from_model(self): + model = self._get_model_proto_with_raw_data(False) + external_names, external_values = extract_raw_data_from_model(model) + self.assertEqual(list(external_names), ["inputs"]) + self.assertEqual(len(external_values), 1) + self.assertEqual(external_values[0].numpy(), [0.0]) + + def test_has_external_data(self): + model = self._get_model_proto_with_raw_data() + self.assertTrue(has_external_data(model)) + + def test_has_external_data_with_no_external_data(self): + model = self._get_model_proto_with_raw_data(False) + self.assertFalse(has_external_data(model)) + + def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto: + input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None]) + output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None]) + raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes() + tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True) + if has_external_data: + set_external_data(tensor, location="foo.bin") + node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"]) + return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor])) diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index d66591318d5c7..2ee4b5e1a173d 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -210,6 +210,11 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c continue; } const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()]; + if (!node_arg) { + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name() << " because it is not found in the graph."; + continue; + } const auto [is_tensor_type, is_allowed_type_for_grad, type] = IsAllowedForGradient(graph_, node_arg); if (is_tensor_type) { if (!is_allowed_type_for_grad) { diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc index 2aade8c9bc1f9..61fc8d5492c2b 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc @@ -44,7 +44,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits. } - NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t); + NodeArg& node_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(X->Name() + "_log_prob"), &t); outputs.push_back(&node_arg);