
Commit

Bring QAT POC back to a functional state
baijumeswani committed Jan 26, 2024
1 parent 656ca66 commit 04a85dc
Showing 4 changed files with 27 additions and 51 deletions.
@@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t

> **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC.
-> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True`
+> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True`
> **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag QuantizeBias=False`
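
For reference, a minimal sketch of how these two flags could be passed to `onnxruntime.quantization.quantize_static` when producing the QDQ model. The model file names and the random calibration reader are placeholders, not the POC's actual quantization script:

```python
import numpy as np

from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class RandomCalibrationDataReader(CalibrationDataReader):
    """Placeholder calibration reader; the POC calibrates with real MNIST batches."""

    def __init__(self, input_name="input", num_samples=10):
        self._samples = iter(
            [{input_name: np.random.rand(1, 784).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._samples, None)


quantize_static(
    "mnist_fp32.onnx",  # float model exported from PyTorch (placeholder name)
    "mnist_qdq.onnx",   # QDQ model that QAT starts from (placeholder name)
    RandomCalibrationDataReader(),
    quant_format=QuantFormat.QDQ,
    extra_options={
        "AddQDQPairToWeight": True,  # keep a dedicated QDQ pair on the weights
        "QuantizeBias": False,       # leave the bias term unquantized
    },
)
```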
56 changes: 20 additions & 36 deletions orttraining/orttraining/test/python/qat_poc_example/model.py
@@ -5,7 +5,7 @@
import onnx
import torch

-import onnxruntime.training.onnxblock as onnxblock
+from onnxruntime.training import artifacts


class MNIST(torch.nn.Module):
@@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix):
4. The checkpoint file
"""

-    class MNISTWithLoss(onnxblock.TrainingModel):
-        def __init__(self):
-            super().__init__()
-            self.loss = onnxblock.loss.CrossEntropyLoss()
-
-        def build(self, output_name):
-            return self.loss(output_name)
-
-    mnist_with_loss = MNISTWithLoss()
-    onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None
-
-    # Build the training and eval graphs
-    logging.info("Using onnxblock to create the training artifacts.")
-    with onnxblock.onnx_model(onnx_model) as model_accessor:
-        _ = mnist_with_loss(onnx_model.graph.output[0].name)
-        eval_model = model_accessor.eval_model
-
-    # Build the optimizer graph
-    optimizer = onnxblock.optim.AdamW()
-    with onnxblock.onnx_model() as accessor:
-        _ = optimizer(mnist_with_loss.parameters())
-        optimizer_model = accessor.model
+    onnx_model = onnx.load(model_path)
+
+    requires_grad = [
+        param.name
+        for param in onnx_model.graph.initializer
+        if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point"))
+    ]
+    artifacts.generate_artifacts(
+        onnx_model,
+        requires_grad=requires_grad,
+        loss=artifacts.LossType.CrossEntropyLoss,
+        optimizer=artifacts.OptimType.AdamW,
+        artifact_directory=artifacts_dir,
+        prefix=model_prefix,
+    )
+
-    # Create the training artifacts
-    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
-    logging.info(f"Saving the training model to {train_model_path}.")
-    onnx.save(onnx_model, train_model_path)
-    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
-    logging.info(f"Saving the eval model to {eval_model_path}.")
-    onnx.save(eval_model, eval_model_path)
-    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
-    logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
-    onnx.save(optimizer_model, optimizer_model_path)
-    trainable_params, non_trainable_params = mnist_with_loss.parameters()
-    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
-    logging.info(f"Saving the checkpoint to {checkpoint_path}.")
-    onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)
+    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx")
+    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx")
+    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx")
+    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint")

    return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
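
A small usage sketch of the rewritten helper, with illustrative paths; the real driver in `qat.py` below passes a prefix of `"mnist_qat_"`:

```python
import os

# Illustrative call; "mnist_qdq.onnx" and "model_dir" are placeholder names.
artifacts_dir = os.path.join("model_dir", "training_artifacts")
os.makedirs(artifacts_dir, exist_ok=True)

train_model, eval_model, optimizer_model, checkpoint = create_training_artifacts(
    "mnist_qdq.onnx", artifacts_dir, "mnist_qat_"
)
# With that prefix, generate_artifacts is expected to write
# mnist_qat_training_model.onnx, mnist_qat_eval_model.onnx,
# mnist_qat_optimizer_model.onnx and mnist_qat_checkpoint into artifacts_dir,
# which is what the returned paths point at.
```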
2 changes: 1 addition & 1 deletion orttraining/orttraining/test/python/qat_poc_example/qat.py
@@ -46,7 +46,7 @@
)

logging.info("Preparing the training artifacts for QAT.")
training_model_name = "mnist_qat"
training_model_name = "mnist_qat_"
artifacts_dir = os.path.join(model_dir, "training_artifacts")
utils.makedir(artifacts_dir)
training_artifacts = create_training_artifacts(
18 changes: 5 additions & 13 deletions orttraining/orttraining/test/python/qat_poc_example/train.py
@@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader):
    model.train()
    cumulative_loss = 0
    for data, target in train_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        train_loss = model(forward_inputs)
+        train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
        optimizer.step()
        model.lazy_reset_grad()
-        cumulative_loss += train_loss[0]
+        cumulative_loss += train_loss

    return cumulative_loss / len(train_loader)

@@ -43,12 +39,8 @@ def _eval(model, test_loader):
    model.eval()
    cumulative_loss = 0
    for data, target in test_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        test_loss = model(forward_inputs)
-        cumulative_loss += test_loss[0]
+        test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
+        cumulative_loss += test_loss

    return cumulative_loss / len(test_loader)

@@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp
    train_loader, test_loader = _get_dataloaders("data", batch_size)

    # Load the checkpoint state.
-    state = orttraining.CheckpointState(qat_checkpoint)
+    state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint)

    # Create the training module.
    model = orttraining.Module(qat_train_model, state, qat_eval_model)
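
For context, a minimal sketch of how the loaded checkpoint, training module, and optimizer fit together in this loop, assuming `orttraining` is `onnxruntime.training.api` and reusing the artifact names from above; the batch data is random placeholder input:

```python
import numpy as np
from onnxruntime.training import api as orttraining  # assumed alias matching the diff

state = orttraining.CheckpointState.load_checkpoint("training_artifacts/mnist_qat_checkpoint")
model = orttraining.Module(
    "training_artifacts/mnist_qat_training_model.onnx",
    state,
    "training_artifacts/mnist_qat_eval_model.onnx",
)
optimizer = orttraining.Optimizer("training_artifacts/mnist_qat_optimizer_model.onnx", model)

model.train()
data = np.random.rand(64, 784).astype(np.float32)          # placeholder batch
target = np.random.randint(0, 10, (64,)).astype(np.int64)  # placeholder labels
loss = model(data, target)   # forward + backward on the training graph
optimizer.step()             # apply the AdamW update
model.lazy_reset_grad()      # reset gradients before the next step
```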
