Skip to content

Commit

Permalink
inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed May 7, 2024
1 parent dc45be4 commit 3cd3f98
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def save_model(self, out_dir):
# Create ONNX model
ir_model = ir.Model(
ir.Graph(
self.inputs,
self.outputs,
self.names_to_values(self.inputs),
self.names_to_values(self.outputs),
nodes=self.nodes.values(),
initializers=self.initializers,
opset_imports={"": 14, "com.microsoft": 1}
Expand Down Expand Up @@ -372,14 +372,15 @@ def make_node(self, op_type: str, inputs: Sequence[str], outputs: Sequence[str],
self.make_constant(input_name)

# Make node only if it does not already exist
if name not in self.nodes:
if name is None or name not in self.nodes:
input_values = self.names_to_values(inputs)
node = ir.Node(domain, op_type, input_values, attributes=ir_convenience.convert_attributes(kwargs), num_outputs=len(outputs), name=name, doc_string=doc_string)
for val, name in zip(node.outputs, outputs):
val.name = name
for val, name_ in zip(node.outputs, outputs):
val.name = name_
# Register the value to the model
self.values[name] = val
self.nodes[name] = node
self.values[name_] = val
if name is not None:
self.nodes[name] = node
else:
node = self.nodes[name]

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable node is not used.

Expand Down Expand Up @@ -423,6 +424,8 @@ def make_inputs(self):
inputs.append(value_name)
self.make_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"])

self.inputs = inputs

def make_outputs(self):
# Add model-specific outputs to list of model outputs
outputs = []
Expand Down Expand Up @@ -1832,10 +1835,10 @@ def make_rotary_embedding_caches_subgraph(self):
if_name = f"{basename}/If"
if_cos_cache_output, if_sin_cache_output = "cos_cache", "sin_cache"

cos_cache_large_node = ir.Node("", "Constant", name="/large/cos_cache/Constant", attributes=[ir.AttrTensor(ir.Tensor(cos_cache_large))])
sin_cache_large_node = ir.Node("", "Constant", name="/large/sin_cache/Constant", attributes=[ir.AttrTensor(ir.Tensor(sin_cache_large))])
cos_cache_small_node = ir.Node("", "Constant", name="/small/cos_cache/Constant", attributes=[ir.AttrTensor(ir.Tensor(cos_cache_small))])
sin_cache_small_node = ir.Node("", "Constant", name="/small/sin_cache/Constant", attributes=[ir.AttrTensor(ir.Tensor(sin_cache_small))])
cos_cache_large_node = ir.Node("", "Constant", [], name="/large/cos_cache/Constant", attributes=[ir.AttrTensor("value", ir.Tensor(cos_cache_large))])
sin_cache_large_node = ir.Node("", "Constant", [], name="/large/sin_cache/Constant", attributes=[ir.AttrTensor("value", ir.Tensor(sin_cache_large))])
cos_cache_small_node = ir.Node("", "Constant", [], name="/small/cos_cache/Constant", attributes=[ir.AttrTensor("value", ir.Tensor(cos_cache_small))])
sin_cache_small_node = ir.Node("", "Constant", [], name="/small/sin_cache/Constant", attributes=[ir.AttrTensor("value", ir.Tensor(sin_cache_small))])
self.make_node(
"If",
inputs=[f"{greater_name}/output_0"],
Expand Down

0 comments on commit 3cd3f98

Please sign in to comment.