
feat: auto split onnx model (https://github.com/zkonduit/ezkl/discussions/744) #855

Open · wants to merge 1 commit into base main

Conversation

@ExcellentHH (Author)

A simple and naive solution to the discussion topic (#744) I raised.

Script Explanation

This script addresses the challenge of generating correctness proofs for large ONNX models on machines with limited hardware. For instance, if a machine can only handle circuits of up to 2^24 rows but the model requires significantly more, proving becomes infeasible.

To overcome this, the script automatically partitions a large model into multiple smaller sub-models based on a given upper threshold, and protects the intermediate results passed between sub-models with hashing to preserve privacy.

By splitting the large model, this approach enables verification of larger models on machines with average hardware. It also allows parallel validation: multiple sub-models can be validated simultaneously using multithreading or multiple machines, improving overall efficiency.
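A minimal sketch of the partitioning idea described above (the function name and per-node cost model are illustrative, not the PR's actual code): walk the nodes in topological order and start a new sub-model whenever the running cost estimate would exceed the threshold.

```python
def partition_nodes(sorted_nodes, cost, upper_bound):
    """Greedily cut a topologically sorted node list into sub-models
    whose total estimated cost stays under upper_bound."""
    subgraphs, current, current_cost = [], [], 0
    for node in sorted_nodes:
        c = cost[node]
        # start a new sub-model if adding this node would overflow the bound
        if current and current_cost + c > upper_bound:
            subgraphs.append(current)
            current, current_cost = [], 0
        current.append(node)
        current_cost += c
    if current:
        subgraphs.append(current)
    return subgraphs

# e.g. four nodes of cost 3 with a bound of 6 split into two sub-models
print(partition_nodes(["a", "b", "c", "d"], {n: 3 for n in "abcd"}, 6))
```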

@alexander-camuto
Collaborator

@ExcellentHH thanks so much for getting this over the finish line. Will review this in a bit but first off huge congratulations on seeing this through 🎉🎉🎉🎉🍾🍾🍾🍾🍾

@jasonmorton
Member

Great stuff indeed! I think we're going to want to change to KZG commitments rather than hashes to save rows, and add a final loop to actually compute the proofs, glue, and verify in an integration test.

@JSeam2
Collaborator

JSeam2 commented Oct 27, 2024

Amazing work! One usability caveat that might cause issues is networks with recurrent structures.

The following lines should deal fine with DAG-type networks:

    # Topological sorting
    topo_sorted_nodes = list(nx.topological_sort(G))

For better usability it might be worth flagging cycles in the network to users, and providing an error message saying the scheme does not support such networks.
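A hedged sketch of that check (assuming G is the networkx DiGraph built from the ONNX nodes, as in the snippet above):

```python
import networkx as nx

def topo_sort_or_fail(G):
    """Topologically sort G, failing loudly on recurrent structures."""
    if not nx.is_directed_acyclic_graph(G):
        cycle = nx.find_cycle(G)
        raise ValueError(
            f"Model graph contains a cycle {cycle}; the splitting scheme "
            "only supports DAG (feed-forward) networks."
        )
    return list(nx.topological_sort(G))

# a small feed-forward chain sorts fine
G = nx.DiGraph([("conv1", "relu1"), ("relu1", "fc1")])
print(topo_sort_or_fail(G))
```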


@alexander-camuto (Collaborator) left a comment


This is a great start for this. I've added some comments.

It would also be great to add an integration test for all this that splits a smaller model like nanoGPT into 2^17-row chunks, generates the witness for each one, proves all the chunks, then verifies that the proof commitments match using the method from the example notebooks :)

Would also note that we need to ensure that the input_scale for subgraphs with index > 0 is the same as the output scale of the prior subgraph. You can see how we do this in the example notebooks. Right now it is calibrating over all scales; fixing this will also shorten runtime significantly.

There are some other subtleties to discuss but let's resolve these issues first. If you don't have the bandwidth to solve all of these, let me know and I can help you get it over the line.
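On the scale-matching point, a hedged sketch of the chaining step (the "model_output_scales" field name follows the settings.json produced by ezkl in the example notebooks and may differ across versions):

```python
import json

def next_input_scale(prev_settings_path):
    """Return the output scale of the previous subgraph, to be reused as
    the input_scale of the next one instead of re-calibrating all scales.
    Assumes an ezkl-style settings.json with a "model_output_scales" list."""
    with open(prev_settings_path) as f:
        prev = json.load(f)
    return prev["model_output_scales"][0]
```

The returned value would then be set on the next subgraph's run args before generating its settings, so calibration only has to search the remaining scales.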

    help="Input shape for the ONNX model in JSON format. Default is '{\"input\": [1, 3, 224, 224]}'.")
    parser.add_argument("--simplify", action='store_true',
                        help="Flag to indicate if the model should be simplified. Default is False.")


can we add an argument to specify the visibility of the stitched commitments eg. public, hashed, polycommit
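For instance (the flag name here is hypothetical, not from the PR):

```python
import argparse

parser = argparse.ArgumentParser()
# hypothetical flag controlling how the stitched intermediates are exposed
parser.add_argument(
    "--intermediate-visibility",
    choices=["public", "hashed", "polycommit"],
    default="hashed",
    help="Visibility of the intermediate values stitched between subgraphs.",
)

args = parser.parse_args(["--intermediate-visibility", "polycommit"])
print(args.intermediate_visibility)  # polycommit
```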

    parser.add_argument("--simplify", action='store_true',
                        help="Flag to indicate if the model should be simplified. Default is False.")

    args = parser.parse_args()

can we add a flag to (optionally) compile the circuit (or not) -- but if the upper_bound_per_subgraph is 23 logrows, for example, it should 1. calibrate, 2. reduce the logrows to 23 manually, 3. compile
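A hedged sketch of step 2, the manual clamp (assumes the calibrated settings.json keeps logrows under run_args, as in ezkl's settings files; adjust the key path if the format differs):

```python
import json

def clamp_logrows(settings_path, max_logrows):
    """After calibration, force logrows down to the per-subgraph bound
    before compiling the circuit."""
    with open(settings_path) as f:
        settings = json.load(f)
    if settings["run_args"]["logrows"] > max_logrows:
        settings["run_args"]["logrows"] = max_logrows
    with open(settings_path, "w") as f:
        json.dump(settings, f)
    return settings["run_args"]["logrows"]
```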

    is_pass = False
    with open(json_file, 'r') as f:
        data = json.load(f)
        total_assignments = data.get("total_assignments", 0)

this should be checking num_rows, not total_assignments
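i.e. something along these lines (assuming the generated JSON exposes a num_rows field alongside total_assignments):

```python
import json

def fits_bound(json_file, upper_bound):
    """Check whether a subgraph fits the prover's capacity."""
    with open(json_file) as f:
        data = json.load(f)
    # compare the padded circuit row count, not the raw assignment count
    return data.get("num_rows", 0) <= upper_bound
```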


    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Process ONNX model and generate subgraphs.")
        parser.add_argument("--onnx_model_path", type=str, default='./resnet18.onnx', help="Path to the ONNX model. Default is './resnet18.onnx'.")

can we change the default to match the library's default of network.onnx

    res = ezkl.gen_settings(temp_model_name, py_run_args=run_args)
    assert res == True

    data_path = f"input_data_{subgraph_index}.json"

something I realized would be super helpful to make this more usable: create a separate script that takes the partitioned ONNX files and generates a new set of intermediate / subgraph input JSON files from an original input file.

So, for example: if I have input.json and subgraph1.onnx, subgraph2.onnx, subgraph3.onnx, the script generate_subgraph_input.json should yield input files for each subgraph: input_data_1.json, input_data_2.json, input_data_3.json. Even if input_data_1.json = input.json, it would still be useful to make this explicit.

Just came to this realization because I ended up having to code this up myself, and it would be useful.
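A hedged sketch of that helper (run_model stands in for an actual inference call, e.g. via onnxruntime; file naming follows the PR's input_data_{i}.json convention):

```python
import json

def generate_subgraph_inputs(input_json, subgraph_paths, run_model):
    """Write input_data_{i}.json for each subgraph by running the chain:
    subgraph i's input is subgraph i-1's output.
    run_model(model_path, data) -> output data for the next subgraph."""
    with open(input_json) as f:
        data = json.load(f)["input_data"]
    for i, model_path in enumerate(subgraph_paths, start=1):
        # input_data_1.json simply mirrors input.json, but is written
        # out explicitly so every subgraph has its own input file
        with open(f"input_data_{i}.json", "w") as f:
            json.dump({"input_data": data}, f)
        data = run_model(model_path, data)
```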

@ExcellentHH (Author)

@alexander-camuto @jasonmorton @JSeam2 Thank you for the response and encouragement. Special thanks to @alexander-camuto for the improvement suggestions—I’ve learned a lot from them, and I apologize for the immaturity of my code. I’ve been a bit busy recently, but I’ll work on the code improvements as soon as possible. Thanks again!
