Commit 8e76b59
Remove the requirement to push the whole Framework Module as input to the ForgeModule during compile time (#591)

The Framework Module was previously required for parameter initialization.

- This change is required in order to remove the constraints added during op test generation.
- It also cleans up the generated module a bit and serializes all required parameters, which are then simply loaded within the model during compile time (see the sketch below).
- It represents one of the steps toward generating single-op tests focused on a specific model.

Fix #590
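
For context, a sketch of the generated parameter loader before and after this change, reconstructed from the codegen diff below (the enclosing generated class and the optional TVM-params block are omitted; "<module>" stands for the generated module name):

    # Before: the original framework model had to be passed in at compile time.
    def process_framework_parameters(self, model):
        named_parameters = dict(model.state_dict().items())
        named_buffers = dict(model.named_buffers())
        named_parameters.update(named_buffers)

    # After: everything is loaded from files serialized during codegen,
    # so no framework model is needed anymore.
    def process_framework_parameters(self):
        named_parameters = torch.load("<module>_names_params.pt")
        named_buffers = torch.load("<module>_named_buffers.pt")
        named_parameters.update(named_buffers)
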
nvukobratTT authored Nov 5, 2024
1 parent f5b51b6 commit 8e76b59
Showing 2 changed files with 31 additions and 8 deletions.
12 changes: 7 additions & 5 deletions forge/forge/python_codegen.py
@@ -259,17 +259,19 @@ def write_forward(self, ops, inputs, outputs):
         self.indent = 0
         self.wl("")
 
-    def write_param_parser(self, param_names, param_file_name):
+    def write_param_parser(
+        self, param_names, param_file_name, names_params_file_name=None, named_buffers_file_name=None
+    ):
         self.indent = 1
 
         if self.framework == "pytorch":
-            self.wl(f"def process_framework_parameters(self, model):")
+            self.wl(f"def process_framework_parameters(self):")
             self.indent += 1
-            self.wl(f"named_parameters = dict(model.state_dict().items())")
+            self.wl(f"named_parameters = torch.load('{names_params_file_name}')")
             if param_file_name is not None:
                 self.wl(f'serialized_params = torch.load("{param_file_name}")')
                 self.wl(f"named_parameters.update(serialized_params)")
-            self.wl("named_buffers = dict(model.named_buffers())")
+            self.wl(f"named_buffers = torch.load('{named_buffers_file_name}')")
             self.wl("named_parameters.update(named_buffers)")
 
         if len(param_names):
@@ -1249,7 +1251,7 @@ def write_param_parser(self, param_names, param_file_name):
         self.indent = 1
 
         if self.framework == "pytorch":
-            self.wl(f"def process_framework_parameters(self, model):")
+            self.wl(f"def process_framework_parameters(self):")
             self.indent += 1
 
             self.wl("named_parameters = dict(model.named_parameters())")
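
For readers unfamiliar with the writer API in these hunks: wl() emits one line of generated Python source at the current indent level. A minimal sketch of that mechanism, assuming a file-backed writer (the real ForgeWriter carries more state, e.g. self.framework and the module name):

    class Writer:
        def __init__(self, out):
            self.out = out      # open handle to the generated .py file
            self.indent = 0     # current indent level, in 4-space steps

        def wl(self, line):
            # "write line": emit one indented line of generated source
            self.out.write("    " * self.indent + line + "\n")
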
27 changes: 24 additions & 3 deletions forge/forge/tvm_to_python.py
@@ -2039,7 +2039,12 @@ def generate_forge_module(
             forge_mod.module.process_framework_parameters(framework_mod.module)
         else:
             forge_mod = TestClass(writer.module_name)
-            forge_mod.process_framework_parameters(framework_mod.module)
+
+            if isinstance(framework_mod, forge.module.PyTorchModule):
+                forge_mod.process_framework_parameters()
+            else:
+                forge_mod.process_framework_parameters(framework_mod.module)
+
         assert not any(
             [param.value() is None for param in forge_mod.get_parameters()]
         ), f"Could not retrieve parameters from framework and tvm"
@@ -2655,8 +2660,24 @@ def delete_unneeded_outputs(ops, returns):
         param_file_name = os.path.join(writer.module_directory, writer.module_name + "_params.pt")
         torch.save(params_from_tvm, param_file_name)
 
-    param_names.update(const_names)
-    writer.write_param_parser(param_names, param_file_name)
+    if framework == "pytorch":
+        # Store named parameters
+        names_params_file_name = os.path.join(writer.module_directory, writer.module_name + "_names_params.pt")
+        named_parameters = dict(framework_mod.module.state_dict().items())
+        torch.save(named_parameters, names_params_file_name)
+
+        # Store named buffers
+        named_buffers_file_name = os.path.join(writer.module_directory, writer.module_name + "_named_buffers.pt")
+        named_buffers = dict(framework_mod.module.named_buffers())
+        torch.save(named_buffers, named_buffers_file_name)
+
+        # Generate Forge module parameter parser
+        param_names.update(const_names)
+        writer.write_param_parser(param_names, param_file_name, names_params_file_name, named_buffers_file_name)
+    else:
+        param_names.update(const_names)
+        writer.write_param_parser(param_names, param_file_name)
 
     writer.close_file()
 
     modules.append(writer)
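
For reference, a standalone sketch of the serialization round-trip these hunks introduce, using a toy module (the file names mirror the generated *_names_params.pt / *_named_buffers.pt; BatchNorm1d is chosen only because its state_dict() carries both parameters and buffers):

    import torch

    model = torch.nn.BatchNorm1d(4)

    # What generate_forge_module now saves at codegen time:
    torch.save(dict(model.state_dict().items()), "toy_names_params.pt")  # weight, bias, running_mean, ...
    torch.save(dict(model.named_buffers()), "toy_named_buffers.pt")      # running_mean, running_var, num_batches_tracked

    # What the generated process_framework_parameters() does at compile time:
    named_parameters = torch.load("toy_names_params.pt")
    named_parameters.update(torch.load("toy_named_buffers.pt"))
    assert "weight" in named_parameters and "running_mean" in named_parameters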
