Bring QAT POC back to a functional state #19290

Merged (1 commit, Feb 21, 2024)
@@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t

 > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC.
-> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True`
+> **_NOTE:_** Typically, the weights in the statically quantized onnx model are associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non-quantized. We ensure that the weights have a dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True`
 > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag `QuantizeBias=False`
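For context, a minimal sketch of how those two flags reach ORT's static quantizer via `extra_options`. The model paths, the input name `"input"`, and the random calibration reader are placeholders for illustration, not part of this PR:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static

class RandomCalibrationReader(CalibrationDataReader):
    """Toy reader feeding a few random MNIST-shaped batches; a real run would use training data."""

    def __init__(self, num_batches=8):
        self._batches = iter(
            [{"input": np.random.rand(1, 784).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        return next(self._batches, None)

quantize_static(
    "mnist.onnx",      # float model exported from PyTorch (placeholder path)
    "mnist_qdq.onnx",  # statically quantized QDQ output (placeholder path)
    RandomCalibrationReader(),
    quant_format=QuantFormat.QDQ,
    extra_options={
        "AddQDQPairToWeight": True,  # weights keep a full QDQ pair and stay trainable for QAT
        "QuantizeBias": False,       # bias is left unquantized rather than quantized to int32
    },
)
```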
56 changes: 20 additions & 36 deletions orttraining/orttraining/test/python/qat_poc_example/model.py
@@ -5,7 +5,7 @@
 import onnx
 import torch

-import onnxruntime.training.onnxblock as onnxblock
+from onnxruntime.training import artifacts


 class MNIST(torch.nn.Module):
@@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix):
     4. The checkpoint file
     """

-    class MNISTWithLoss(onnxblock.TrainingModel):
-        def __init__(self):
-            super().__init__()
-            self.loss = onnxblock.loss.CrossEntropyLoss()
-
-        def build(self, output_name):
-            return self.loss(output_name)
-
-    mnist_with_loss = MNISTWithLoss()
-    onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None
-
-    # Build the training and eval graphs
-    logging.info("Using onnxblock to create the training artifacts.")
-    with onnxblock.onnx_model(onnx_model) as model_accessor:
-        _ = mnist_with_loss(onnx_model.graph.output[0].name)
-        eval_model = model_accessor.eval_model
-
-    # Build the optimizer graph
-    optimizer = onnxblock.optim.AdamW()
-    with onnxblock.onnx_model() as accessor:
-        _ = optimizer(mnist_with_loss.parameters())
-        optimizer_model = accessor.model
+    onnx_model = onnx.load(model_path)
+
+    requires_grad = [
+        param.name
+        for param in onnx_model.graph.initializer
+        if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point"))
+    ]
+    artifacts.generate_artifacts(
+        onnx_model,
+        requires_grad=requires_grad,
+        loss=artifacts.LossType.CrossEntropyLoss,
+        optimizer=artifacts.OptimType.AdamW,
+        artifact_directory=artifacts_dir,
+        prefix=model_prefix,
+    )

     # Create the training artifacts
-    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
-    logging.info(f"Saving the training model to {train_model_path}.")
-    onnx.save(onnx_model, train_model_path)
-    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
-    logging.info(f"Saving the eval model to {eval_model_path}.")
-    onnx.save(eval_model, eval_model_path)
-    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
-    logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
-    onnx.save(optimizer_model, optimizer_model_path)
-    trainable_params, non_trainable_params = mnist_with_loss.parameters()
-    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
-    logging.info(f"Saving the checkpoint to {checkpoint_path}.")
-    onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)
+    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx")
+    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx")
+    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx")
+    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint")

     return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
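To make the new flow concrete, here is a hypothetical call to the rewritten helper. `generate_artifacts` writes the four files `{prefix}training_model.onnx`, `{prefix}eval_model.onnx`, `{prefix}optimizer_model.onnx`, and `{prefix}checkpoint` into the artifact directory, which is exactly what the path construction above mirrors; the QDQ model path below is a stand-in:

```python
# Hypothetical invocation; "mnist_qdq.onnx" stands in for the static quantizer's output.
train_model, eval_model, optim_model, ckpt = create_training_artifacts(
    "mnist_qdq.onnx",      # statically quantized model with QDQ nodes
    "training_artifacts",  # directory generate_artifacts writes into
    "mnist_qat_",          # prefix prepended to every artifact filename
)
```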
2 changes: 1 addition & 1 deletion orttraining/orttraining/test/python/qat_poc_example/qat.py
@@ -46,7 +46,7 @@
 )

 logging.info("Preparing the training artifacts for QAT.")
-training_model_name = "mnist_qat"
+training_model_name = "mnist_qat_"
 artifacts_dir = os.path.join(model_dir, "training_artifacts")
 utils.makedir(artifacts_dir)
 training_artifacts = create_training_artifacts(
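Why the trailing underscore matters, as a quick sketch: the artifacts API concatenates the prefix directly onto each artifact name, with no separator of its own (see the `f"{model_prefix}training_model.onnx"` construction in model.py above):

```python
# Filename assembly, conceptually: plain string concatenation of prefix and artifact name.
prefix = "mnist_qat_"
print(prefix + "training_model.onnx")       # mnist_qat_training_model.onnx
print("mnist_qat" + "training_model.onnx")  # mnist_qattraining_model.onnx (old, fused name)
```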
18 changes: 5 additions & 13 deletions orttraining/orttraining/test/python/qat_poc_example/train.py
@@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader):
     model.train()
     cumulative_loss = 0
     for data, target in train_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        train_loss = model(forward_inputs)
+        train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
         optimizer.step()
         model.lazy_reset_grad()
-        cumulative_loss += train_loss[0]
+        cumulative_loss += train_loss

     return cumulative_loss / len(train_loader)

@@ -43,12 +39,8 @@ def _eval(model, test_loader):
     model.eval()
     cumulative_loss = 0
     for data, target in test_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        test_loss = model(forward_inputs)
-        cumulative_loss += test_loss[0]
+        test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
+        cumulative_loss += test_loss

     return cumulative_loss / len(test_loader)

@@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkpoint):
     train_loader, test_loader = _get_dataloaders("data", batch_size)

     # Load the checkpoint state.
-    state = orttraining.CheckpointState(qat_checkpoint)
+    state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint)

     # Create the training module.
     model = orttraining.Module(qat_train_model, state, qat_eval_model)
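Putting the pieces together, a minimal sketch of the updated training-loop wiring these hunks rely on. The import alias for `orttraining` and the artifact paths are assumptions based on how train.py uses them; the random batch is a stand-in for the MNIST dataloader:

```python
import numpy as np
from onnxruntime.training import api as orttraining

# Load the checkpoint, then build the training module and optimizer from the artifacts.
state = orttraining.CheckpointState.load_checkpoint("mnist_qat_checkpoint")
model = orttraining.Module("mnist_qat_training_model.onnx", state, "mnist_qat_eval_model.onnx")
optimizer = orttraining.Optimizer("mnist_qat_optimizer_model.onnx", model)

model.train()  # select the training graph
data = np.random.rand(32, 784).astype(np.float32)            # stand-in batch of flattened images
target = np.random.randint(0, 10, size=32).astype(np.int64)  # int64 labels, per the diff above
loss = model(data, target)   # inputs are now passed positionally, not packed into one list
optimizer.step()
model.lazy_reset_grad()      # clear accumulated gradients before the next step
```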