Add support for SGD optimizer in minimal build (#19901)
baijumeswani authored Mar 14, 2024
1 parent 1fb6cbd commit 226f60f
Showing 4 changed files with 53 additions and 77 deletions.
27 changes: 15 additions & 12 deletions orttraining/orttraining/python/training/artifacts.py
@@ -41,7 +41,7 @@ def generate_artifacts(
     requires_grad: Optional[List[str]] = None,
     frozen_params: Optional[List[str]] = None,
     loss: Optional[Union[LossType, onnxblock.Block]] = None,
-    optimizer: Optional[OptimType] = None,
+    optimizer: Optional[Union[OptimType, onnxblock.Block]] = None,
     artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None,
     prefix: str = "",
     ort_format: bool = False,
@@ -64,8 +64,8 @@ def generate_artifacts(
         model: The base model to be used for gradient graph generation.
         requires_grad: List of names of model parameters that require gradient computation
         frozen_params: List of names of model parameters that should be frozen.
-        loss: The loss function enum to be used for training. If None, no loss node is added to the graph.
-        optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
+        loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
+        optimizer: The optimizer enum or onnxblock to be used for training. If None, no optimizer model is generated.
         artifact_directory: The directory to save the generated artifacts.
             If None, the current working directory is used.
         prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
@@ -219,14 +219,6 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
         logging.info("No optimizer enum provided. Skipping optimizer model generation.")
         return
 
-    if not isinstance(optimizer, OptimType):
-        raise RuntimeError(
-            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be of type "
-            "onnxruntime.training.artifacts.OptimType."
-        )
-
-    logging.info("Optimizer enum provided: %s", optimizer.name)
-
     opset_version = None
     for domain in model.opset_import:
         if domain.domain == "" or domain.domain == "ai.onnx":
@@ -235,8 +227,19 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
 
     optim_model = None
     optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
+    optim_block = None
+    if isinstance(optimizer, OptimType):
+        logging.info("Optimizer enum provided: %s", optimizer.name)
+        optim_block = optim_blocks[optimizer]()
+    elif isinstance(optimizer, onnxblock.Block):
+        logging.info("Optimizer block provided: %s", optimizer.__class__.__name__)
+        optim_block = optimizer
+    else:
+        raise TypeError(
+            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be either one of"
+            "onnxruntime.training.artifacts.OptimType or onnxruntime.training.onnxblock.Block."
+        )
 
-    optim_block = optim_blocks[optimizer]()
     with onnxblock.empty_base(opset_version=opset_version):
         _ = optim_block(model_params)
         optim_model = optim_block.to_model_proto()
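For context, a minimal usage sketch of the widened API above (not part of this commit): generate_artifacts now accepts either an OptimType enum value such as OptimType.SGD or a custom onnxblock optimizer block. The tiny PyTorch model, file names, and parameter names below are illustrative assumptions.

# Illustrative sketch only (assumes the onnxruntime-training Python package and a
# hypothetical two-layer model exported from PyTorch); not taken from this commit.
import onnx
import torch
from onnxruntime.training import artifacts

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

torch.onnx.export(TinyNet(), torch.randn(64, 784), "tiny_net.onnx",
                  input_names=["input"], output_names=["output"])

# Either an OptimType enum value (AdamW or SGD) or an onnxblock optimizer block
# (e.g. onnxblock.optim.SGD()) can now be passed as the optimizer argument.
artifacts.generate_artifacts(
    onnx.load("tiny_net.onnx"),
    requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.SGD,
    artifact_directory="sgd_artifacts",
    ort_format=True,  # also emit ORT-format files for a minimal build
)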
27 changes: 27 additions & 0 deletions (Python test file)
@@ -1072,3 +1072,30 @@ def test_save_nominal_checkpoint():
         os.stat(os.path.join(temp_dir, "checkpoint")).st_size
         > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size
     )
+
+
+def test_custom_optimizer_block():
+    device = "cpu"
+    batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
+    _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
+    weight_decay = 123
+    optimizer = onnxblock.optim.AdamW(weight_decay=weight_decay)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        artifacts.generate_artifacts(
+            base_model,
+            requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=optimizer,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+
+        optimizer_model = onnx.load(os.path.join(temp_dir, "optimizer_model.onnx"))
+        for node in optimizer_model.graph.node:
+            if node.op_type == "AdamW":
+                for attr in node.attribute:
+                    if attr.name == "weight_decay":
+                        assert attr.f == weight_decay
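A companion check for the SGD path, in the same style as the test above, might look like the following sketch (not part of this commit). It reuses the _get_models helper and the imports of the test module shown here, and assumes the generated SGD optimizer graph uses the com.microsoft SGDOptimizerV2 op that the C++ change below recognizes.

# Hypothetical SGD variant of the test above; relies on the same test-module helpers
# (_get_models) and imports (os, tempfile, onnx, onnxblock, artifacts) as test_custom_optimizer_block.
def test_sgd_optimizer_artifacts_sketch():
    _, base_model = _get_models("cpu", 64, 784, 500, 10)

    with tempfile.TemporaryDirectory() as temp_dir:
        artifacts.generate_artifacts(
            base_model,
            requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
            loss=artifacts.LossType.CrossEntropyLoss,
            optimizer=artifacts.OptimType.SGD,
            artifact_directory=temp_dir,
        )

        optimizer_model = onnx.load(os.path.join(temp_dir, "optimizer_model.onnx"))
        # The generated graph is expected to contain the com.microsoft SGDOptimizerV2 op.
        assert any(node.op_type == "SGDOptimizerV2" for node in optimizer_model.graph.node)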
71 changes: 10 additions & 61 deletions orttraining/orttraining/training_api/optimizer.cc
@@ -61,32 +61,19 @@ Status GraphInputsAreExpected(gsl::span<const std::string> actual_graph_inputs,
 }  // namespace
 
 std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    std::shared_ptr<Model> model, int32_t& group_count) {
+    const GraphViewer& graph_viewer, int32_t& group_count) {
   std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (model != nullptr) {
-    Graph& graph = model->MainGraph();
-    for (auto& node : graph.Nodes()) {
-      if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
-        auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
-        if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
-          opt_type_to_freq_map[domain_type_pair] = 0;
-        }
-
-        opt_type_to_freq_map[domain_type_pair] += 1;
-      }
-    }
-  } else {
-#else
-  ORT_UNUSED_PARAMETER(model);
-#endif
-    // TODO(baijumeswani): Figure out the best way to extract the optimizer type
-    // from the model (either onnx model or ort format model) or from the checkpoint.
-    // For now, assume that the optimizer type is AdamWOptimizer when using ort format models.
-    opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
-#if !defined(ORT_MINIMAL_BUILD)
-  }
-#endif
+
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
+      auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
+      if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
+        opt_type_to_freq_map[domain_type_pair] = 0;
+      }
+
+      opt_type_to_freq_map[domain_type_pair] += 1;
+    }
+  }
 
   ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " +
                                                      std::to_string(opt_type_to_freq_map.size()));
@@ -105,42 +92,6 @@ std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance
   }
 }
 
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const PathString& optim_path, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModel(optim_path)) {
-    ORT_ENFORCE(Model::Load(optim_path, model, nullptr,
-                            logging::LoggingManager::DefaultLogger())
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_path);
-#endif
-  return CreateInstance(model, group_count);
-}
-
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast<int>(optim_model_data_len))) {
-    ONNX_NAMESPACE::ModelProto model_proto;
-    ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast<int>(optim_model_data_len)) == true,
-                "Failed to load model because protobuf parsing failed.");
-
-    ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr,
-                            logging::LoggingManager::DefaultLogger(), ModelOptions(true, true))
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_model_data);
-  ORT_UNUSED_PARAMETER(optim_model_data_len);
-#endif
-
-  return CreateInstance(model, group_count);
-}
-
 Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) {
   auto group_optimizer_state_it =
       optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@@ -280,17 +231,15 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers,
     auto optimizer_model = std::get<std::optional<std::string>>(model_identifiers.optim_model);
     // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value()));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_);
   } else {
     auto optimizer_model = std::get<gsl::span<const uint8_t>>(model_identifiers.optim_model);
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(),
                                          static_cast<int>(optimizer_model.size())));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(),
-                                                                   optimizer_model.size(),
-                                                                   group_count_);
   }
 
   ORT_THROW_IF_ERROR(optim_sess_->Initialize());
+  optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_sess_->GetSessionState().GetGraphViewer(),
+                                                                 group_count_);
 
   // Make sure that the checkpoint state can copy tensors
   state_->optimizer_checkpoint_state.optimizer_session_data_transfer_mgr = &optim_sess_->GetDataTransferManager();
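To connect the pieces, here is a hedged sketch (not from this commit) of consuming the generated artifacts with the onnxruntime.training.api classes. Because CreateInstance now reads the optimizer graph through the session's GraphViewer, the same flow is expected to work in a minimal build when the ORT-format artifacts are loaded instead of the .onnx files. The paths, shapes, and batch size below are assumptions.

# Sketch of a training step with the generated SGD artifacts (file names assume the
# default, prefix-less output of generate_artifacts; a minimal build would load the
# corresponding .ort files instead of .onnx).
import numpy as np
from onnxruntime.training import api as ort_api

state = ort_api.CheckpointState.load_checkpoint("sgd_artifacts/checkpoint")
module = ort_api.Module(
    "sgd_artifacts/training_model.onnx",
    state,
    "sgd_artifacts/eval_model.onnx",
    device="cpu",
)
# With this change, the optimizer algorithm (SGD here) is inferred from the loaded
# optimizer session graph rather than by re-parsing the optimizer model file.
optimizer = ort_api.Optimizer("sgd_artifacts/optimizer_model.onnx", module)

module.train()
inputs = np.random.randn(64, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=(64,)).astype(np.int64)
loss = module(inputs, labels)  # forward + backward
optimizer.step()               # apply the SGD update to the parameters
module.lazy_reset_grad()       # clear accumulated gradients for the next step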
5 changes: 1 addition & 4 deletions orttraining/orttraining/training_api/optimizer.h
@@ -64,11 +64,8 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase {
 };
 
 struct OptimizerAlorithmFactory {
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const PathString& optim_path,
+  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const GraphViewer& graph_viewer,
                                                                 int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const uint8_t* optim_model_data,
-                                                                size_t optim_model_data_len, int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(std::shared_ptr<Model> model, int32_t& group_count);
 };
 
 struct CheckpointState;