Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Dec 25, 2023
1 parent 4e1859e commit 3c3b4bf
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
# --------------------------------------------------------------------------


from typing import Dict, List, Optional, Tuple, Union
import ctypes
from typing import Dict, List, Optional, Tuple, Union

import torch
from onnx import ModelProto, NodeProto, TensorProto, helper

from onnxruntime.training.utils import pytorch_type_to_onnx_dtype

from ._pythonop_helper import make_pythonop_node
Expand Down Expand Up @@ -45,19 +47,23 @@ def post_processing_enable_mem_efficient_training(
def _get_param_pull_trigger_name(param_name: str) -> str:
return f"pull_{param_name}"


# Create weight retrieving PythonOp.
inputs = [helper.make_tensor_value_info(
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, # Use the same data type with output for the input
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)]

outputs = [helper.make_tensor_value_info(
_get_param_pull_trigger_name(pname),
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
) for pname in trainable_named_params]
inputs = [
helper.make_tensor_value_info(
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, # Use the same data type with output for the input
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)
]

outputs = [
helper.make_tensor_value_info(
_get_param_pull_trigger_name(pname),
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)
for pname in trainable_named_params
]

weight_pull_node = make_pythonop_node(
"weight_pull_trigger",
Expand All @@ -68,33 +74,33 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
safe_run_mode=0,
)


graph_inputs_to_remove = []
for graph_input in reversed(exported_model.graph.input):
if graph_input.name not in trainable_named_params:
continue


graph_inputs_to_remove.append(graph_input)

if graph_input.name not in consumer_map:
continue

# Create the param retrieval function for this parameter.
node_inputs = [helper.make_tensor_value_info(
node_inputs = [
helper.make_tensor_value_info(
_get_param_pull_trigger_name(graph_input.name),
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
), graph_input.name, # Second param is a string, which represent the param_name
),
graph_input.name, # Second param is a string, which represent the param_name
]

node_outputs = [
helper.make_tensor_value_info(
graph_input.name, # output use the same name as weight
int(pytorch_type_to_onnx_dtype(trainable_named_params[graph_input.name].dtype)),
list(trainable_named_params[graph_input.name].shape),
),]

graph_input.name, # output use the same name as weight
int(pytorch_type_to_onnx_dtype(trainable_named_params[graph_input.name].dtype)),
list(trainable_named_params[graph_input.name].shape),
),
]

new_node = make_pythonop_node(
f"weight_retrieval_{graph_input.name}",
Expand All @@ -114,13 +120,10 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
exported_model.graph.input.insert(0, inputs[0])
exported_model.graph.node.insert(0, weight_pull_node)


return exported_model


def _create_param_trigger_function(
trainable_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]
):
def _create_param_trigger_function(trainable_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]):
"""This function is used to create a weight retrieving function using trainable_named_params."""

class ParamTriggerFunction(torch.autograd.Function):
Expand Down Expand Up @@ -153,21 +156,10 @@ def infer_shape(

return tensor_output_shapes, tensor_output_dtypes




# func_full_qual_name = get_fully_qualified_class_name(ParamTriggerFunction)
# register_torch_autograd_function(func_full_qual_name, ParamTriggerFunction)
# register_custom_function_schema_supplementary(ParamTriggerFunction)

# return func_full_qual_name

return ParamTriggerFunction


def _create_param_retrieval_function(
trainable_named_params: Dict[str, torch.nn.parameter.Parameter]
):
def _create_param_retrieval_function(trainable_named_params: Dict[str, torch.nn.parameter.Parameter]):
"""This function is used to create a weight retrieving function using trainable_named_params."""

class ParamRetrievalFunction(torch.autograd.Function):
Expand Down Expand Up @@ -199,14 +191,13 @@ def infer_shape(
# Restore the nn.Module from the pointer.
param_name = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value

tensor_output_shapes = [list(trainable_named_params[param_name].shape),]
tensor_output_dtypes = [int(pytorch_type_to_onnx_dtype(trainable_named_params[param_name].dtype)),]
tensor_output_shapes = [
list(trainable_named_params[param_name].shape),
]
tensor_output_dtypes = [
int(pytorch_type_to_onnx_dtype(trainable_named_params[param_name].dtype)),
]

return tensor_output_shapes, tensor_output_dtypes

# func_full_qual_name = get_fully_qualified_class_name(ParamRetrievalFunction)
# register_torch_autograd_function(func_full_qual_name, ParamRetrievalFunction)
# register_custom_function_schema_supplementary(ParamRetrievalFunction)

# return func_full_qual_name
return ParamRetrievalFunction
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@
# --------------------------------------------------------------------------

from __future__ import annotations

import inspect

import onnx
import torch

from onnxruntime.capi._pybind_state import (
register_miscellaneous_const_input,
register_torch_autograd_function,
)
from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function

from ._utils import get_fully_qualified_class_name
from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary


import onnx
from ._utils import get_fully_qualified_class_name

PYTHON_OP_DOMAIN = "com.microsoft"
PYTHON_OP_TYPE = "PythonOp"
Expand Down Expand Up @@ -46,17 +42,20 @@ def set_safe_run_mode(model: onnx.ModelProto, allowed_unsafe_run_python_op_names

return model


_PYTHON_OP_INCRE_INDEX = [0]


def make_pythonop_node(
name_prefix: str,
inputs: list[onnx.ValueInfoProto | int | bool | float | tuple[int, ...] | tuple[bool, ...] | tuple[float, ...] | object ],
inputs: list[
onnx.ValueInfoProto | int | bool | float | tuple[int, ...] | tuple[bool, ...] | tuple[float, ...] | object
],
outputs: list[onnx.ValueInfoProto],
func_class: torch.autograd.Function,
training_mode:int,
safe_run_mode:int,
) -> onnx.NodeProto:

training_mode: int,
safe_run_mode: int,
) -> onnx.NodeProto:
assert issubclass(func_class, torch.autograd.Function), "func_class must be a subclass of torch.autograd.Function."

assert len(inputs) > 0, f"inputs must not be empty for function {func_class}."
Expand All @@ -65,7 +64,10 @@ def make_pythonop_node(
all_input_parameters: list[inspect.Parameter] = list(inspect.signature(func_class.forward).parameters.values())

# Remove the first parameter (ctx) from inspected parameter list.
assert len(inputs) == len(all_input_parameters) - 1, f"The number of inputs ({len(inputs)}) must match the number of parameters ({len(all_input_parameters) - 1}) of the forward function."
assert len(inputs) == len(all_input_parameters) - 1, (
f"The number of inputs ({len(inputs)}) must match the number of parameters "
f"({len(all_input_parameters) - 1}) of the forward function."
)

func_full_qual_name = get_fully_qualified_class_name(func_class)

Expand Down Expand Up @@ -106,10 +108,10 @@ def make_pythonop_node(
tensor_args.append(arg.name)
input_tensor_types.append(arg.type.tensor_type.elem_type)
input_tensor_ranks.append(len(arg.type.tensor_type.shape.dim))
cconv += 'd'
cconv += "d"
continue

cconv += 'c'
cconv += "c"

# Got a non-tensor variable.
if isinstance(arg, float):
Expand Down Expand Up @@ -180,7 +182,6 @@ def make_pythonop_node(
output_tensor_types.append(arg.type.tensor_type.elem_type)
output_tensor_ranks.append(len(arg.type.tensor_type.shape.dim))


attrs = {
"func_name": func_full_qual_name,
"input_convention": cconv,
Expand Down Expand Up @@ -219,7 +220,6 @@ def make_pythonop_node(
attrs["input_pointer_scalars"] = input_pointer_scalars
attrs["input_pointer_scalar_positions"] = input_pointer_scalar_positions


# Register function with class names.
register_torch_autograd_function(func_full_qual_name, func_class)

Expand Down

0 comments on commit 3c3b4bf

Please sign in to comment.