diff --git a/python/oneflow/framework/infer_compiler/import_tools/format_utils.py b/python/oneflow/framework/infer_compiler/import_tools/format_utils.py
index d963426b140..d14026794bd 100644
--- a/python/oneflow/framework/infer_compiler/import_tools/format_utils.py
+++ b/python/oneflow/framework/infer_compiler/import_tools/format_utils.py
@@ -40,7 +40,13 @@ def _format_full_class_name(self, obj: Union[str, type, FunctionType]):
 
         elif isinstance(obj, FunctionType):
             module = inspect.getmodule(obj)
-            obj = f"{module.__name__}.{obj.__qualname__}"
+            if (
+                module.__name__ == "torch.nn.functional"
+                and obj.__qualname__ == "boolean_dispatch.<locals>.fn"
+            ):
+                obj = f"{module.__name__}.{obj.__name__}"
+            else:
+                obj = f"{module.__name__}.{obj.__qualname__}"
 
         assert isinstance(obj, str), f"obj must be str, but got {type(obj)}"
 
diff --git a/python/oneflow/framework/infer_compiler/with_fx_graph.py b/python/oneflow/framework/infer_compiler/with_fx_graph.py
index 881b720793c..3cea89ad8e3 100644
--- a/python/oneflow/framework/infer_compiler/with_fx_graph.py
+++ b/python/oneflow/framework/infer_compiler/with_fx_graph.py
@@ -38,7 +38,6 @@ def fx_node_tranform(gm):
         # Align this with env setting in `with_oneflow_compile`.
         # Otherwise, infererence using PyTorch with OneFlow backend on
         # multiple input shapes may crash
-        os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1")
         os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1")
         os.environ.setdefault("ONEFLOW_MLIR_CSE", "1")
         os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1")
@@ -63,6 +62,7 @@ def fx_node_tranform(gm):
         os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1")
 
         class OfGraph(flow.nn.Graph):
+            @flow.nn.Graph.with_dynamic_input_shape()
             def __init__(self):
                 super().__init__()
                 self.fx_md = of_gm
@@ -70,9 +70,15 @@ def __init__(self):
                 self.config.allow_fuse_add_to_output(True)
 
             def build(self, *args, **kwargs):
-                return self.fx_md(*args, **kwargs)
+                if self.fx_md.training:
+                    return self.fx_md(*args, **kwargs)
+                with flow.no_grad():
+                    return self.fx_md(*args, **kwargs)
 
         of_g = OfGraph()
+        DEFAULT_CACHE_SIZE = 9
+        of_g._dynamic_input_graph_cache.set_cache_size(DEFAULT_CACHE_SIZE)
+        of_g._dynamic_input_graph_cache.enable_shared(True)
         oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)
 
     return oneflow_fn
diff --git a/python/oneflow/framework/infer_compiler/with_oneflow_backend.py b/python/oneflow/framework/infer_compiler/with_oneflow_backend.py
index 23fcb5aa684..cc5c2012765 100644
--- a/python/oneflow/framework/infer_compiler/with_oneflow_backend.py
+++ b/python/oneflow/framework/infer_compiler/with_oneflow_backend.py
@@ -46,8 +46,22 @@ def input_fn(value):
             )
         else:
             output = transformed_fn(*args, **kwargs)
-        if isinstance(output, tuple):
-            return tuple(flow.utils.tensor.to_torch(i) for i in output)
-        return flow.utils.tensor.to_torch(output)
+
+        def output_fn(value):
+            if isinstance(value, flow.Tensor):
+                return flow.utils.tensor.to_torch(value)
+            else:
+                return value
+
+        if isinstance(output, (tuple, list, flow._oneflow_internal.TensorTuple)):
+            return tuple(output_fn(i) for i in output)
+        elif isinstance(output, dict):
+            return {k: output_fn(v) for (k, v) in output.items()}
+        elif isinstance(output, flow.Tensor):
+            return output_fn(output)
+        else:
+            raise NotImplementedError(
+                f"How to handle {type(output)} output type is not implemented"
+            )
 
     return wrapped_forward
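Context for the `format_utils.py` hunk: every `torch.nn.functional` routine created through `torch._jit_internal.boolean_dispatch` (for example `max_pool2d`) shares the generic `__qualname__` `boolean_dispatch.<locals>.fn`, which cannot be mapped back to an importable name, while `__name__` carries the real function name. A minimal sketch illustrating this (assumes a regular PyTorch install; not part of the patch):

```python
import inspect
import torch.nn.functional as F

# boolean_dispatch sets __name__ and __module__ on the dispatcher it returns,
# but __qualname__ stays the generic nested-function one.
print(F.max_pool2d.__qualname__)  # boolean_dispatch.<locals>.fn
print(F.max_pool2d.__name__)      # max_pool2d

# This is the importable path the patched branch now produces:
module = inspect.getmodule(F.max_pool2d)
print(f"{module.__name__}.{F.max_pool2d.__name__}")  # torch.nn.functional.max_pool2d
```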
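Context for the `with_fx_graph.py` hunks: instead of forcing `ONEFLOW_RUN_GRAPH_BY_VM`, the graph now uses `nn.Graph`'s dynamic-input-shape cache, so a single `OfGraph` instance keeps up to `DEFAULT_CACHE_SIZE` (9) compiled graphs keyed by input shape, shares state between them, and runs `build` under `flow.no_grad()` when the module is in eval mode. A hedged usage sketch of the same cache machinery (OneFlow's documented examples apply `with_dynamic_input_shape` to the `Graph` subclass itself, which is the form used here; `MyGraph` and the `Linear` module are illustrative, not from the patch):

```python
import oneflow as flow

@flow.nn.Graph.with_dynamic_input_shape()
class MyGraph(flow.nn.Graph):
    def __init__(self, module):
        super().__init__()
        self.m = module

    def build(self, x):
        with flow.no_grad():  # inference-only, mirroring the patched build()
            return self.m(x)

g = MyGraph(flow.nn.Linear(8, 4))
g._dynamic_input_graph_cache.set_cache_size(9)    # mirrors DEFAULT_CACHE_SIZE
g._dynamic_input_graph_cache.enable_shared(True)  # as set on of_g in the patch

y1 = g(flow.randn(2, 8))   # first shape compiles and caches one graph
y2 = g(flow.randn(16, 8))  # a new shape compiles a second cached graph instead of crashing
```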
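Context for the `with_oneflow_backend.py` hunk: the old conversion assumed the wrapped module returned either a bare tensor or a tuple of tensors, so lists, dicts, `TensorTuple`s, and outputs with non-tensor fields (e.g. `None`) were not handled and broke inside `flow.utils.tensor.to_torch`. The new `output_fn` converts only tensor leaves and passes anything else through. A self-contained sketch of that rule; `FakeTensor` and `to_torch` are stand-ins for `flow.Tensor` and `flow.utils.tensor.to_torch`, so this runs without OneFlow installed:

```python
class FakeTensor:  # stand-in for flow.Tensor
    def __init__(self, data):
        self.data = data

def to_torch(t):  # stand-in for flow.utils.tensor.to_torch
    return ("torch_tensor", t.data)

def convert_output(output):
    # Convert tensor leaves; pass non-tensor values (None, ints, ...) through.
    def output_fn(value):
        return to_torch(value) if isinstance(value, FakeTensor) else value

    if isinstance(output, (tuple, list)):
        return tuple(output_fn(i) for i in output)
    elif isinstance(output, dict):
        return {k: output_fn(v) for k, v in output.items()}
    elif isinstance(output, FakeTensor):
        return output_fn(output)
    raise NotImplementedError(f"How to handle {type(output)} output type is not implemented")

# A dict output with a non-tensor field now round-trips instead of raising:
print(convert_output({"sample": FakeTensor([1, 2]), "aux": None}))
# {'sample': ('torch_tensor', [1, 2]), 'aux': None}
```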