Skip to content

Commit

Permalink
Add ONNX Builder
Browse files Browse the repository at this point in the history
  • Loading branch information
NewLandTV committed Feb 20, 2024
1 parent d155527 commit b091728
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Dataset/

# Model
*.pth
*.onnx

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
38 changes: 38 additions & 0 deletions Build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import onnx
import torch
import torch.onnx
from Model import CNN

batchSize = 4

# Load model
model = CNN()

model.load_state_dict(torch.load("Model.pth"))
model.eval()

# Export onnx
x = torch.randn(batchSize, 1, 28, 28, requires_grad = True)
out = model(x)
path = "Model.onnx"

torch.onnx.export(
model,
x,
path,
export_params = True,
opset_version = 10,
do_constant_folding = True,
input_names = ["input"],
output_names = ["output"],
dynamic_axes = {
"input": {
0: "batch_size"
},
"output": {
0: "batch_size"
}
}
)

onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
19 changes: 19 additions & 0 deletions Requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
certifi==2024.2.2
charset-normalizer==3.3.2
filelock==3.13.1
fsspec==2024.2.0
idna==3.6
Jinja2==3.1.3
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.4
onnx==1.15.0
pillow==10.2.0
protobuf==4.25.3
requests==2.31.0
sympy==1.12
torch==2.2.0
torchvision==0.17.0
typing_extensions==4.9.0
urllib3==2.2.1

0 comments on commit b091728

Please sign in to comment.