Skip to content

Commit

Permalink
Update on "[executorch][serialization] Move DataSegment into shared c…
Browse files Browse the repository at this point in the history
…ommon.fbs"

So that `DataSegment` can be shared by incoming data.fbs.

Differential Revision: [D65434369](https://our.internmc.facebook.com/intern/diff/D65434369/)

[ghstack-poisoned]
  • Loading branch information
lucylq committed Nov 12, 2024
2 parents 03626df + 6475894 commit 93ce8eb
Show file tree
Hide file tree
Showing 226 changed files with 9,498 additions and 1,863 deletions.
3 changes: 2 additions & 1 deletion .ci/scripts/test_llama_runner_eager.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ run_and_verify() {
-d fp32 \
--max_seq_length 32 \
--temperature 0 \
--show_tokens \
--prompt "Once upon a time," > result.txt

# Verify result.txt
RESULT=$(cat result.txt)
EXPECTED_RESULT="there was a little girl"
EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
echo "Actual result: ${RESULT}"
echo "Success"
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ghstack_land.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches:
- 'gh/cccclai/[0-9]+/base'
- 'gh/dbort/[0-9]+/base'
- 'gh/dvorjackz/[0-9]+/base'
- 'gh/guangy10/[0-9]+/base'
- 'gh/helunwencser/[0-9]+/base'
- 'gh/jorgep31415/[0-9]+/base'
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@
[submodule "third-party/pybind11"]
path = third-party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third-party/ao"]
path = third-party/ao
url = https://github.com/pytorch/ao.git
11 changes: 8 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -721,10 +721,15 @@ if(EXECUTORCH_BUILD_PYBIND)
-fPIC
-frtti
-fexceptions
# libtorch is built with the old ABI, so we need to do the same for any
# .cpp files that include torch, c10, or ATen targets.
-D_GLIBCXX_USE_CXX11_ABI=0
)
if(EXECUTORCH_DO_NOT_USE_CXX11_ABI)
# libtorch is built with the old ABI, so we need to do the same for any
# .cpp files that include torch, c10, or ATen targets. Note that PyTorch
# nightly binary is built with _GLIBCXX_USE_CXX11_ABI set to 0 while its
# CI build sets this to 1 (default)
list(APPEND _pybind_compile_options -D_GLIBCXX_USE_CXX11_ABI=0)
endif()

# util lib
add_library(
util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ We recommend using the latest release tag from the
See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code
style, CI jobs, and other development topics.

To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
* Head to the `#executorch-general` channel for general questions, discussion, and community support.
* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development.


## Directory Structure

```
Expand Down
13 changes: 13 additions & 0 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# @noautodeps
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
Expand Down Expand Up @@ -69,6 +70,18 @@ python_library(
],
)

python_library(
name = "tosa_specification",
srcs = [
"tosa_specification.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/packaging:packaging",
"//executorch/exir/backend:compile_spec_schema",
],
)

python_library(
name = "tosa_utils",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ python_library(
deps = [
"//executorch/backends/arm:tosa_quant_utils",
"//executorch/backends/arm:tosa_utils",
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
"//executorch/exir:lib",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
get_first_fake_tensor,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs):
return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from executorch.backends.arm._passes.decompose_layernorm_pass import (
DecomposeLayerNormPass,
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
Expand All @@ -43,6 +44,7 @@
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager
Expand All @@ -58,6 +60,7 @@ def transform_to_backend_pipeline(
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(RemoveGetItemPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
Expand All @@ -72,6 +75,7 @@ def transform_to_backend_pipeline(
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
Expand Down
112 changes: 112 additions & 0 deletions backends/arm/_passes/decompose_linear_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
"""
This pass decomposes linear into a Conv2D with the required view operations.
linear(x, weights, bias) becomes:
x_reshaped = view(x)
weights_reshaped = view(weights)
conv2d = conv2d(x_reshaped, weights_reshaped, bias)
output = view(conv2d)
It also inserts q/dq pairs if the linear node was quantized.
"""

def call(self, graph_module):
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target != exir_ops.edge.aten.linear.default:
continue
args = node.args
input = args[0]
weights = args[1]
bias = args[2] if len(args) > 2 else None
output_shape = get_first_fake_tensor(node).shape
input_shape = get_first_fake_tensor(input).shape
weights_shape = get_first_fake_tensor(weights).shape
batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
# input has shape (..., Ci)
input_reshaped_shape = [batches, input_shape[-1], 1, 1]
# weights have shape (Co, Ci)
weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]

with graph_module.graph.inserting_before(node):
quantize = input.op == "call_function" and input.target == dq_op
q_params = input.args[1:] if quantize else None
# Reshape input to 4D with shape (N, Ci, 1, 1)
input_reshaped = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(input, input_reshaped_shape),
kwargs={},
quantize=quantize,
q_params=q_params,
)

quantize = weights.op == "call_function" and weights.target == dq_op
q_params = weights.args[1:] if quantize else None
# Reshape weights to 4D with shape (Co, Ci, 1, 1)
weights_reshaped = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(weights, weights_reshaped_shape),
kwargs={},
quantize=quantize,
q_params=q_params,
)

consumer_node = list(node.users)[0]
quantize = (
consumer_node.op == "call_function" and consumer_node.target == q_op
)
q_params = consumer_node.args[1:] if quantize else None
conv = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.convolution.default,
args=(
input_reshaped,
weights_reshaped,
bias,
[1, 1], # strides
[0, 0], # padding
[1, 1], # dilation
False, # transposed
[0, 0], # output padding
1, # groups
),
kwargs={},
quantize=quantize,
q_params=q_params,
)

with graph_module.graph.inserting_after(conv):
# Reshape output to same rank as original input with shape (..., Co)
# No need to insert q/dq pair as Conv2D node above has inserted them if
# required.
output = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(conv, list(output_shape)),
kwargs={},
)

node.replace_all_uses_with(output)
graph_module.graph.erase_node(node)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
14 changes: 1 addition & 13 deletions backends/arm/_passes/insert_squeeze_after_sum_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass):
sum(dims, keep_dim = False)
After pass:
sum(dims, keep_dim = True)
(q)
(dq)
squeeze(dim = dims)
"""

Expand All @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

dim_list = cast(list[int], sum_node.args[1])
quantized = is_quant_node(sum_node)
if quantized:
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
qparams = qparams + (torch.int8,)
else:
qparams = None

# Add keep_dim = True arg to sum node.
sum_node.args = sum_node.args[0:2] + (True,)
Expand All @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule):
)
sum_node.replace_all_uses_with(squeeze_node)
squeeze_node.args = (sum_node, dim_list)
if quantized:
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/size_adjust_conv2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import cast, Optional

import torch.fx
from executorch.backends.arm.tosa_quant_utils import is_quant_node
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload
Expand Down Expand Up @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
slice_node = graph.create_node(
"call_function", self.slice_op, (last_node,) + args
)
if is_quant_node(last_node):
if is_node_quantized(last_node):
q_params = last_node.args[1:]
dq_node = insert_q_dq_pair(
graph_module.graph, slice_node, q_params
Expand Down
Loading

0 comments on commit 93ce8eb

Please sign in to comment.