diff --git a/.gitignore b/.gitignore index f70e504..e85883e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ Dataset/ # Model *.pth +*.onnx # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/Build.py b/Build.py new file mode 100644 index 0000000..f203aea --- /dev/null +++ b/Build.py @@ -0,0 +1,38 @@ +import onnx +import torch +import torch.onnx +from Model import CNN + +batchSize = 4 + +# Load model +model = CNN() + +model.load_state_dict(torch.load("Model.pth")) +model.eval() + +# Export onnx +x = torch.randn(batchSize, 1, 28, 28, requires_grad = True) +out = model(x) +path = "Model.onnx" + +torch.onnx.export( + model, + x, + path, + export_params = True, + opset_version = 10, + do_constant_folding = True, + input_names = ["input"], + output_names = ["output"], + dynamic_axes = { + "input": { + 0: "batch_size" + }, + "output": { + 0: "batch_size" + } + } +) + +onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path) diff --git a/Requirements.txt b/Requirements.txt new file mode 100644 index 0000000..b4e602f --- /dev/null +++ b/Requirements.txt @@ -0,0 +1,19 @@ +certifi==2024.2.2 +charset-normalizer==3.3.2 +filelock==3.13.1 +fsspec==2024.2.0 +idna==3.6 +Jinja2==3.1.3 +MarkupSafe==2.1.5 +mpmath==1.3.0 +networkx==3.2.1 +numpy==1.26.4 +onnx==1.15.0 +pillow==10.2.0 +protobuf==4.25.3 +requests==2.31.0 +sympy==1.12 +torch==2.2.0 +torchvision==0.17.0 +typing_extensions==4.9.0 +urllib3==2.2.1