From 04a85dc98a5fac026f8049aeffe0f752288345c1 Mon Sep 17 00:00:00 2001
From: Baiju Meswani
Date: Fri, 26 Jan 2024 23:36:09 +0000
Subject: [PATCH] Bring QAT POC back to a functional state

---
 .../test/python/qat_poc_example/README.md |  2 +-
 .../test/python/qat_poc_example/model.py  | 56 +++++++------------
 .../test/python/qat_poc_example/qat.py    |  2 +-
 .../test/python/qat_poc_example/train.py  | 18 ++----
 4 files changed, 27 insertions(+), 51 deletions(-)

diff --git a/orttraining/orttraining/test/python/qat_poc_example/README.md b/orttraining/orttraining/test/python/qat_poc_example/README.md
index 6840e98bd9c86..05072b410b730 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/README.md
+++ b/orttraining/orttraining/test/python/qat_poc_example/README.md
@@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t

 > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC.

-> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True`
+> **_NOTE:_** Typically, the weights in the statically quantized onnx model are associated with a DQ node only (not the QDQ pair) since the weights are quantized. However, QAT requires weights and biases to be non-quantized. We ensure that the weights have a dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True`.

 > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag QuantizeBias=False`
diff --git a/orttraining/orttraining/test/python/qat_poc_example/model.py b/orttraining/orttraining/test/python/qat_poc_example/model.py
index 91d7ccd7294f5..601362a59e379 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/model.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/model.py
@@ -5,7 +5,7 @@
 import onnx
 import torch

-import onnxruntime.training.onnxblock as onnxblock
+from onnxruntime.training import artifacts


 class MNIST(torch.nn.Module):
@@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix):
     4. The checkpoint file
     """

-    class MNISTWithLoss(onnxblock.TrainingModel):
-        def __init__(self):
-            super().__init__()
-            self.loss = onnxblock.loss.CrossEntropyLoss()
-
-        def build(self, output_name):
-            return self.loss(output_name)
-
-    mnist_with_loss = MNISTWithLoss()
-    onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None
-
-    # Build the training and eval graphs
-    logging.info("Using onnxblock to create the training artifacts.")
-    with onnxblock.onnx_model(onnx_model) as model_accessor:
-        _ = mnist_with_loss(onnx_model.graph.output[0].name)
-        eval_model = model_accessor.eval_model
-
-    # Build the optimizer graph
-    optimizer = onnxblock.optim.AdamW()
-    with onnxblock.onnx_model() as accessor:
-        _ = optimizer(mnist_with_loss.parameters())
-        optimizer_model = accessor.model
+    onnx_model = onnx.load(model_path)
+
+    requires_grad = [
+        param.name
+        for param in onnx_model.graph.initializer
+        if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point"))
+    ]
+
+    artifacts.generate_artifacts(
+        onnx_model,
+        requires_grad=requires_grad,
+        loss=artifacts.LossType.CrossEntropyLoss,
+        optimizer=artifacts.OptimType.AdamW,
+        artifact_directory=artifacts_dir,
+        prefix=model_prefix,
+    )

     # Create the training artifacts
-    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
-    logging.info(f"Saving the training model to {train_model_path}.")
-    onnx.save(onnx_model, train_model_path)
-    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
-    logging.info(f"Saving the eval model to {eval_model_path}.")
-    onnx.save(eval_model, eval_model_path)
-    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
-    logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
-    onnx.save(optimizer_model, optimizer_model_path)
-    trainable_params, non_trainable_params = mnist_with_loss.parameters()
-    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
-    logging.info(f"Saving the checkpoint to {checkpoint_path}.")
-    onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)
+    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx")
+    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx")
+    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx")
+    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint")

     return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
diff --git a/orttraining/orttraining/test/python/qat_poc_example/qat.py b/orttraining/orttraining/test/python/qat_poc_example/qat.py
index 51a15475ee911..dcc9e116fda7d 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/qat.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/qat.py
@@ -46,7 +46,7 @@
     )

     logging.info("Preparing the training artifacts for QAT.")
-    training_model_name = "mnist_qat"
+    training_model_name = "mnist_qat_"
     artifacts_dir = os.path.join(model_dir, "training_artifacts")
     utils.makedir(artifacts_dir)
     training_artifacts = create_training_artifacts(
diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py
index 9a429d2adc6f1..a25c071c58a48 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/train.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/train.py
@@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader):
model.train() cumulative_loss = 0 for data, target in train_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - train_loss = model(forward_inputs) + train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) optimizer.step() model.lazy_reset_grad() - cumulative_loss += train_loss[0] + cumulative_loss += train_loss return cumulative_loss / len(train_loader) @@ -43,12 +39,8 @@ def _eval(model, test_loader): model.eval() cumulative_loss = 0 for data, target in test_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - test_loss = model(forward_inputs) - cumulative_loss += test_loss[0] + test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) + cumulative_loss += test_loss return cumulative_loss / len(test_loader) @@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp train_loader, test_loader = _get_dataloaders("data", batch_size) # Load the checkpoint state. - state = orttraining.CheckpointState(qat_checkpoint) + state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint) # Create the training module. model = orttraining.Module(qat_train_model, state, qat_eval_model)
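
For reviewers who want to try the new flow in isolation, here is a minimal sketch of the `artifacts.generate_artifacts` call that model.py now relies on. The model path and output directory are hypothetical placeholders; the `requires_grad` filter and the generated file names mirror the diff above.

```python
import onnx

from onnxruntime.training import artifacts

onnx_model = onnx.load("mnist_qat.onnx")  # hypothetical path to the quantized model

# Train only the real parameters; skip the scale/zero-point initializers
# that static quantization inserted into the graph.
requires_grad = [
    param.name
    for param in onnx_model.graph.initializer
    if not param.name.endswith(("_scale", "_zero_point"))
]

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
    prefix="mnist_qat_",
)

# With prefix="mnist_qat_", generate_artifacts writes:
#   training_artifacts/mnist_qat_training_model.onnx
#   training_artifacts/mnist_qat_eval_model.onnx
#   training_artifacts/mnist_qat_optimizer_model.onnx
#   training_artifacts/mnist_qat_checkpoint
```

Note that the prefix is used verbatim in the artifact file names (no separator is inserted), which is why qat.py now sets `training_model_name = "mnist_qat_"` with a trailing underscore.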
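
Similarly, here is a sketch of one optimization step with the training API objects that the patched train.py uses. It assumes `orttraining` is an alias for `onnxruntime.training.api` (train.py's import block is outside this diff) and substitutes a dummy numpy batch for the MNIST dataloader.

```python
import numpy as np

from onnxruntime.training import api as orttraining

# Load the checkpoint state and build the training module + optimizer
# from the artifacts generated above.
state = orttraining.CheckpointState.load_checkpoint("training_artifacts/mnist_qat_checkpoint")
model = orttraining.Module(
    "training_artifacts/mnist_qat_training_model.onnx",
    state,
    "training_artifacts/mnist_qat_eval_model.onnx",
)
optimizer = orttraining.Optimizer("training_artifacts/mnist_qat_optimizer_model.onnx", model)

model.train()  # select the training graph (vs. model.eval())
data = np.random.rand(64, 784).astype(np.float32)              # dummy flattened images
target = np.random.randint(0, 10, size=(64,), dtype=np.int64)  # labels must be int64

train_loss = model(data, target)  # forward + backward; inputs are positional numpy arrays
optimizer.step()                  # apply the AdamW update
model.lazy_reset_grad()           # reset gradients before the next forward pass
```

This also illustrates the two behavioral fixes in the train.py hunks: inputs are passed as positional numpy arrays rather than a single list, and the labels are cast to `np.int64` (not `np.int32`) as required by the cross-entropy loss in the generated training graph.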