Update base for Update on "update llama runner to decode single token"
Right now, the eager runner does not print the generated response until all tokens have been generated. This is a poor experience: we have to wait for the entire generation to finish before seeing any output.

This PR updates the runner to decode each new token immediately after it is generated.
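
A minimal sketch of the behavior change (the names below are hypothetical stand-ins, not the actual runner API):

```python
# Sketch only: these helpers are hypothetical stand-ins for the eager
# runner's generation loop and tokenizer, not the real API.
from typing import List


def generate_tokens(prompt: str, max_new_tokens: int = 5) -> List[int]:
    # Stand-in for token-by-token generation.
    return list(range(max_new_tokens))


def decode(token_ids: List[int]) -> str:
    # Stand-in for the tokenizer's decode step.
    return "".join(f"<tok{t}>" for t in token_ids)


# Before: decode only once all tokens have been generated.
tokens = generate_tokens("hello")
print(decode(tokens))

# After: decode and print each new token immediately after it is generated.
for token in generate_tokens("hello"):
    print(decode([token]), end="", flush=True)
print()
```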

Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)

[ghstack-poisoned]
helunwencser committed Nov 8, 2024
2 parents 1c0c17c + 39e5b91 commit 148e99c
Showing 67 changed files with 3,932 additions and 2,355 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ghstack_land.yml
@@ -5,6 +5,7 @@ on:
branches:
- 'gh/cccclai/[0-9]+/base'
- 'gh/dbort/[0-9]+/base'
- 'gh/dvorjackz/[0-9]+/base'
- 'gh/guangy10/[0-9]+/base'
- 'gh/helunwencser/[0-9]+/base'
- 'gh/jorgep31415/[0-9]+/base'
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -43,6 +43,7 @@
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager
@@ -58,6 +59,7 @@ def transform_to_backend_pipeline(
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(RemoveGetItemPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
9 changes: 3 additions & 6 deletions backends/arm/arm_backend.py
@@ -13,7 +13,7 @@

import logging
import os
from typing import final, List, Optional
from typing import cast, final, List, Optional

import serializer.tosa_serializer as ts
from executorch.backends.arm.arm_vela import vela_compile
@@ -31,6 +31,7 @@
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
from torch.fx import Node

# TOSA backend debug functionality
logger = logging.getLogger(__name__)
@@ -225,6 +226,7 @@ def preprocess( # noqa: C901
node_visitors = get_node_visitors(edge_program)

for node in graph_module.graph.nodes:
node = cast(Node, node)
if node.op == "call_function":
process_call_function(node, tosa_graph, node_visitors)
elif node.op == "placeholder":
@@ -236,9 +238,6 @@
# any checking of compatibility.
dbg_fail(node, tosa_graph, artifact_path)

# TODO: It would be awesome if this dump could somehow be done on top level and not here.
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
# access from top level.
if artifact_path:
tag = _get_first_delegation_tag(graph_module)
dbg_tosa_dump(
@@ -259,6 +258,4 @@ def preprocess( # noqa: C901
else:
raise RuntimeError(f"Unknown format {output_format}")

# Continueing from above. Can I put tosa_graph into this function?
# debug_handle_map = ...
return PreprocessResult(processed_bytes=binary)
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.repeat.default,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -20,6 +20,7 @@
op_get_item,
op_hardtanh,
op_log,
op_max_pool2d,
op_mm,
op_mul,
op_permute,
77 changes: 77 additions & 0 deletions backends/arm/operators/op_max_pool2d.py
@@ -0,0 +1,77 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import cast, List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import get_quant_node_args

from serializer.tosa_serializer import TosaOp


@register_node_visitor
class MaxPool2dVisitor(NodeVisitor):
target = "aten.max_pool2d.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:

input_tensor = inputs[0]
kernel_size = inputs[1].special
stride = inputs[2].special

try:
padding = [*inputs[3].special, *inputs[3].special]
except IndexError:
padding = [0, 0, 0, 0]

accumulator_type = input_tensor.dtype

if is_quant_node:
# The accumulator type is always int8 when the input tensor is an integer type.
accumulator_type = ts.DType.INT8

# Initialize zero points to zero.
input_zp = 0
output_zp = 0

if is_quant_node:
input_zp = get_quant_node_args(
cast(torch.fx.Node, node.all_input_nodes[0])
).zp
output_zp = get_quant_node_args(list(node.users)[0]).zp

attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(
kernel=kernel_size,
stride=stride,
pad=padding,
input_zp=input_zp,
output_zp=output_zp,
accum_dtype=accumulator_type,
)

tosa_graph.addOperator(
TosaOp.Op().MAX_POOL2D,
[input_tensor.name],
[output.name],
attr,
)
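
For reference, this visitor targets the max-pool op produced by ordinary PyTorch pooling; a minimal, self-contained sketch of such a module (names and shapes are illustrative):

```python
import torch


class MaxPool(torch.nn.Module):
    # Minimal module using max_pool2d; when delegated to the Arm/TOSA backend,
    # this is the kind of op MaxPool2dVisitor is registered to lower.
    def __init__(self):
        super().__init__()
        # kernel_size=2, stride=2, no explicit padding -> the visitor falls
        # back to pad = [0, 0, 0, 0].
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(x)


example_inputs = (torch.randn(1, 3, 8, 8),)
print(MaxPool()(*example_inputs).shape)  # torch.Size([1, 3, 4, 4])
```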
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
# TODO: remove?
torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.avg_pool2d.default,
torch.ops.aten.max_pool2d.default,
torch.ops.aten.full.default,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
23 changes: 15 additions & 8 deletions backends/arm/test/common.py
@@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus):

# ==== End of Pytest hooks =====

# ==== Custom Pytest decorators =====


def expectedFailureOnFVP(test_item):
if is_option_enabled("corstone300"):
test_item.__unittest_expecting_failure__ = True
return test_item


# ==== End of Custom Pytest decorators =====


def load_libquantized_ops_aot_lib():
so_ext = {
@@ -181,19 +192,15 @@ def get_tosa_compile_spec_unbuilt(
the compile spec before calling .build() to finalize it.
"""
if not custom_path:
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
prefix="arm_tosa_"
)
else:
intermediate_path = custom_path
custom_path = maybe_get_tosa_collate_path()

if not os.path.exists(intermediate_path):
os.makedirs(intermediate_path, exist_ok=True)
if custom_path is not None and not os.path.exists(custom_path):
os.makedirs(custom_path, exist_ok=True)
compile_spec_builder = (
ArmCompileSpecBuilder()
.tosa_compile_spec()
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(intermediate_path)
.dump_intermediate_artifacts_to(custom_path)
)

return compile_spec_builder
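
A rough usage sketch for the two helpers touched above; the import path, argument names, and unittest framing are assumptions inferred from this diff, not verified against the full test suite:

```python
import unittest

# Assumed import path; the helpers live in backends/arm/test/common.py.
from executorch.backends.arm.test import common


class ExampleArmTest(unittest.TestCase):
    # Expected to fail only when the corstone300 (FVP) option is enabled.
    @common.expectedFailureOnFVP
    def test_on_fvp(self):
        # Per the docstring, adjust the returned builder as needed, then call
        # .build() to finalize the compile spec.
        compile_spec = common.get_tosa_compile_spec_unbuilt(  # arguments assumed
            permute_memory_to_nhwc=True
        ).build()
        self.assertIsNotNone(compile_spec)
```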
5 changes: 4 additions & 1 deletion backends/arm/test/misc/test_debug_feats.py
@@ -107,7 +107,10 @@ def test_numerical_diff_prints(self):
ArmTester(
model,
example_inputs=model.get_inputs(),
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
compile_spec=common.get_tosa_compile_spec(
permute_memory_to_nhwc=True,
custom_path=tempfile.mkdtemp("diff_print_test"),
),
)
.export()
.to_edge()
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_cat.py
@@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
def test_cat_4d_tosa_MI(self):
square = torch.ones((2, 2, 2, 2))
for dim in range(-3, 3):
test_data = ((square, square), dim)
test_data = ((square, square.clone()), dim)
self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

@parameterized.expand(Cat.test_parameters)