[bugfix] enable faster rcnn and sd model with oneflow backend #10439

Open · wants to merge 2 commits into base: master

Changes from 1 commit
@@ -40,7 +40,13 @@ def _format_full_class_name(self, obj: Union[str, type, FunctionType]):

elif isinstance(obj, FunctionType):
module = inspect.getmodule(obj)
obj = f"{module.__name__}.{obj.__qualname__}"
if (
module.__name__ == "torch.nn.functional"
and obj.__qualname__ == "boolean_dispatch.<locals>.fn"
):
obj = f"{module.__name__}.{obj.__name__}"
else:
obj = f"{module.__name__}.{obj.__qualname__}"

assert isinstance(obj, str), f"obj must be str, but got {type(obj)}"

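The special case in the hunk above exists because the functions that torch.nn.functional creates via boolean_dispatch report the factory's __qualname__ rather than their public name, so the original f-string produced an unusable path. A minimal, torch-free sketch of that Python behavior (this boolean_dispatch is a hypothetical stand-in for the real one):

```python
# Functions created inside a factory carry the factory's qualified name,
# "boolean_dispatch.<locals>.fn", not their public name. Setting __name__
# afterwards (as torch does) fixes __name__ but not __qualname__.
def boolean_dispatch(name):
    def fn():
        pass
    fn.__name__ = name  # public name, e.g. "max_pool2d"
    return fn

max_pool2d = boolean_dispatch("max_pool2d")
print(max_pool2d.__qualname__)  # boolean_dispatch.<locals>.fn
print(max_pool2d.__name__)      # max_pool2d
```

This is why the patch falls back to __name__ for exactly that qualname pattern.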
9 changes: 7 additions & 2 deletions python/oneflow/framework/infer_compiler/with_fx_graph.py
@@ -38,7 +38,6 @@ def fx_node_tranform(gm):
# Align this with env setting in `with_oneflow_compile`.
# Otherwise, infererence using PyTorch with OneFlow backend on
# multiple input shapes may crash
os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1")
os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1")
os.environ.setdefault("ONEFLOW_MLIR_CSE", "1")
os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1")
@@ -63,16 +62,22 @@ def fx_node_tranform(gm):
os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1")

class OfGraph(flow.nn.Graph):
@flow.nn.Graph.with_dynamic_input_shape()
def __init__(self):
super().__init__()
self.fx_md = of_gm
self.config.enable_cudnn_conv_heuristic_search_algo(False)
self.config.allow_fuse_add_to_output(True)

def build(self, *args, **kwargs):
return self.fx_md(*args, **kwargs)
if self.fx_md.training:
return self.fx_md(*args, **kwargs)
with flow.no_grad():
Contributor:

Distinguishing training from inference mode with flow.no_grad should not, in principle, appear in the build function here; it belongs in the user's model code. For the error reported in the issue, please confirm whether the corresponding backward op is actually missing, and if so, fix the problem by adding that backward op.

Contributor Author:

I asked Juncheng, who developed fused_multi_head_attention_inference, and he said this op only implements the forward pass, not the backward. If we don't add this in build, would the code in the test-compile repository need to be modified? I verified that calling model.eval() alone does not avoid the error reported in the issue.

return self.fx_md(*args, **kwargs)

of_g = OfGraph()
of_g._dynamic_input_graph_cache.set_cache_size(9)
of_g._dynamic_input_graph_cache.enable_shared(True)
Contributor:

Do these two parameters correspond to the size and dynamic parameters in the options of the compile_from_torch interface? torch.compile takes a dynamic parameter, and I think we should use the user-supplied dynamic value rather than hardcoding True. As for size, which is set to the default of 9 here, it could be defined as a named constant instead of a magic number.

Contributor Author:

They basically correspond. The size value can indeed be changed; I'll add a constant. I don't think dynamic needs changing: first, the user's argument is passed to torch, so the oneflow backend cannot access it; second, because the torch.compile frontend sits in front, hardcoding dynamic to True here is equivalent to setting it to the user-supplied value.

oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)

return oneflow_fn
20 changes: 17 additions & 3 deletions python/oneflow/framework/infer_compiler/with_oneflow_backend.py
@@ -46,8 +46,22 @@ def input_fn(value):
)
else:
output = transformed_fn(*args, **kwargs)
if isinstance(output, tuple):
return tuple(flow.utils.tensor.to_torch(i) for i in output)
return flow.utils.tensor.to_torch(output)

def output_fn(value):
if isinstance(value, flow.Tensor):
return flow.utils.tensor.to_torch(value)
else:
return value

if isinstance(output, (tuple, list, flow._oneflow_internal.TensorTuple)):
return tuple(output_fn(i) for i in output)
elif isinstance(output, dict):
return {k: output_fn(v) for (k, v) in output.items()}
elif isinstance(output, flow.Tensor):
return output_fn(output)
else:
raise NotImplementedError(
f"How to handle {type(output)} output type is not implemented"
)

return wrapped_forward
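For reference, the tuple/list/dict branches above can be generalized to arbitrarily nested outputs. A hedged, torch/oneflow-free sketch in which to_torch is a stub standing in for flow.utils.tensor.to_torch; unlike the PR code, this variant recurses into nested containers and preserves the container type rather than flattening lists to tuples:

```python
# Recursive output conversion: walk tuples, lists, and dicts, converting
# every leaf with the (stubbed) tensor converter.
def to_torch(x):  # stand-in for flow.utils.tensor.to_torch
    return f"torch({x})"

def convert(output):
    if isinstance(output, (tuple, list)):
        # rebuild with the same container type
        return type(output)(convert(i) for i in output)
    if isinstance(output, dict):
        return {k: convert(v) for k, v in output.items()}
    return to_torch(output)

print(convert({"a": [1, 2], "b": (3,)}))
# {'a': ['torch(1)', 'torch(2)'], 'b': ('torch(3)',)}
```

Whether list outputs should round-trip as lists (as here) or be coerced to tuples (as in the PR) is a design choice worth settling before merge.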