diff --git a/.github/workflows/ghstack_land.yml b/.github/workflows/ghstack_land.yml
index d5ebbcbbb4..12782c66dd 100644
--- a/.github/workflows/ghstack_land.yml
+++ b/.github/workflows/ghstack_land.yml
@@ -11,6 +11,7 @@ on:
       - 'gh/kimishpatel/[0-9]+/base'
       - 'gh/kirklandsign/[0-9]+/base'
       - 'gh/larryliu0820/[0-9]+/base'
+      - 'gh/lucylq/[0-9]+/base'
       - 'gh/manuelcandales/[0-9]+/base'
       - 'gh/mcr229/[0-9]+/base'
       - 'gh/swolchok/[0-9]+/base'
diff --git a/README.md b/README.md
index be7ff32229..da2cb82ef9 100644
--- a/README.md
+++ b/README.md
@@ -92,7 +92,7 @@ tools.
 ├── runtime # Core C++ runtime.
 |   ├── backend # Backend delegate runtime APIs.
 |   ├── core # Core structures used across all levels of the runtime.
-|   ├── executor # Model loading, initalization, and execution.
+|   ├── executor # Model loading, initialization, and execution.
 |   ├── kernel # Kernel registration and management.
 |   ├── platform # Layer between architecture specific code and portable C++.
 ├── schema # ExecuTorch PTE file format flatbuffer
diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index d03e4a1385..7309287998 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -62,6 +62,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten.select_copy.int,
             exir_ops.edge.aten._log_softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index babfbcfea0..a8ddf1c8f0 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -28,6 +28,7 @@
     op_relu,
     op_repeat,
     op_rsqrt,
+    op_select,
     op_sigmoid,
     op_slice,
     op_squeeze,
diff --git a/backends/arm/operators/op_select.py b/backends/arm/operators/op_select.py
new file mode 100644
index 0000000000..6037ed000c
--- /dev/null
+++ b/backends/arm/operators/op_select.py
@@ -0,0 +1,69 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+
+from executorch.backends.arm.tosa_mapping import TosaArg
+
+from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class SelectVisitor(NodeVisitor):
+    target = "aten.select_copy.int"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+
+        assert len(inputs) == 3
+        input_node, dim, index = inputs
+        shape = input_node.shape
+        rank = len(shape)
+
+        # Normalize negative dim against the rank, and negative index
+        # against the size of the selected dimension.
+        dim = dim.number % rank if dim.number < 0 else dim.number
+        index = index.number % shape[dim] if index.number < 0 else index.number
+
+        # For aten.select_copy, the output rank is rank(input) - 1,
+        # but TOSA requires rank(in) == rank(out).
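+        # E.g. (illustrative numbers, ignoring any dim_order permutation):
+        # select(x, dim=1, index=2) on a (5, 3, 20) input yields a (5, 20)
+        # output. SLICE keeps rank 3 with start=(0, 2, 0), size=(5, 1, 20),
+        # giving a (5, 1, 20) intermediate that RESHAPE collapses to (5, 20).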
+        # Add an intermediate with the same rank.
+        expanded_shape = tuple(1 if i == dim else shape[i] for i in range(rank))
+        expanded_shape = tosa_shape(expanded_shape, input_node.dim_order)
+
+        output_reshaped = tosa_graph.addIntermediate(
+            expanded_shape, ts.DType.INT8 if is_quant_node else output.dtype
+        )
+
+        attr_slice = ts.TosaSerializerAttribute()
+
+        start_attr = [index if i == dim else 0 for i in input_node.dim_order]
+        size_attr = [
+            1 if i == dim else input_node.shape[i] for i in input_node.dim_order
+        ]
+
+        attr_slice.SliceAttribute(start_attr, size_attr)
+
+        tosa_graph.addOperator(
+            TosaOp.Op().SLICE, [input_node.name], [output_reshaped.name], attr_slice
+        )
+
+        # Reshape back to original rank of output.
+        build_reshape(tosa_graph, output_reshaped.name, output.shape, output.name)
diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py
index a490991693..f91df1398e 100644
--- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py
+++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py
@@ -34,6 +34,8 @@
     # torch.ops.aten.view_as_real.default,
     # torch.ops.aten.view_as_real_copy.default,
     torch.ops.aten.view_copy.default,
+    torch.ops.aten.select.int,
+    torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
     torch.ops.aten.slice_copy.Tensor,
     # 'concat' should be handled separately as it has a sequence of inputs and
diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py
new file mode 100644
index 0000000000..fdb2fa1463
--- /dev/null
+++ b/backends/arm/test/ops/test_select.py
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
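+
+# These tests export select/select_copy modules, quantize where applicable,
+# partition to the Arm backend, and either compare against eager-mode outputs
+# (TOSA pipelines) or compile for Ethos-U55/U85 (ethos pipelines).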
+
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+test_data_t = tuple[torch.Tensor, int, int]
+
+test_data_suite: list[tuple[test_data_t]] = [
+    # (test_data, dim, index)
+    ((torch.zeros(5, 3, 20), -1, 0),),
+    ((torch.zeros(5, 3, 20), 0, -1),),
+    ((torch.zeros(5, 3, 20), 0, 4),),
+    ((torch.ones(10, 10, 10), 0, 2),),
+    ((torch.rand(5, 3, 20, 2), 0, 2),),
+    ((torch.rand(10, 10) - 0.5, 0, 0),),
+    ((torch.randn(10) + 10, 0, 1),),
+    ((torch.randn(10) - 10, 0, 2),),
+    ((torch.arange(-16, 16, 0.2), 0, 1),),
+]
+
+
+class TestSelect(unittest.TestCase):
+    class SelectCopy(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, dim: int, index: int):
+            return torch.select_copy(x, dim=dim, index=index)
+
+    class SelectInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, dim: int, index: int):
+            return torch.select(x, dim=dim, index=index)
+
+    def _test_select_tosa_MI_pipeline(
+        self,
+        module: torch.nn.Module,
+        test_data: test_data_t,
+        export_target: str,
+    ):
+        # For 4D tensors, do not permute to NHWC
+        permute = False if len(test_data[0].shape) == 4 else True
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(
+                    permute_memory_to_nhwc=permute
+                ),
+            )
+            .export()
+            .check([export_target])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_select_tosa_BI_pipeline(
+        self,
+        module: torch.nn.Module,
+        test_data: test_data_t,
+        export_target: str,
+    ):
+        # For 4D tensors, do not permute to NHWC
+        permute = False if len(test_data[0].shape) == 4 else True
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(
+                    permute_memory_to_nhwc=permute
+                ),
+            )
+            .quantize()
+            .export()
+            .check([export_target])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .dump_artifact()
+            .dump_operator_distribution()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_select_ethos_BI_pipeline(
+        self,
+        compile_spec: list[CompileSpec],
+        module: torch.nn.Module,
+        test_data: test_data_t,
+        export_target: str,
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize()
+            .export()
+            .check([export_target])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .dump_artifact()
+            .dump_operator_distribution()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    def _test_select_tosa_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t, export_target: str
+    ):
+        # For 4D tensors, do not permute to NHWC
+        permute = False if len(test_data[0].shape) == 4 else True
+        self._test_select_ethos_BI_pipeline(
+            common.get_u55_compile_spec(permute_memory_to_nhwc=permute),
+            module,
+            test_data,
+            export_target,
+        )
+
+    def _test_select_tosa_u85_BI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t, export_target: str
+    ):
+        # For 4D tensors, do not permute to NHWC
+        permute = False if len(test_data[0].shape) == 4 else True
+        self._test_select_ethos_BI_pipeline(
+            common.get_u85_compile_spec(permute_memory_to_nhwc=permute),
+            module,
+            test_data,
+            export_target,
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_MI(self, test_data: test_data_t):
+        self._test_select_tosa_MI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_MI(self, test_data: test_data_t):
+        self._test_select_tosa_MI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_BI(self, test_data: test_data_t):
+        self._test_select_tosa_BI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_BI(self, test_data: test_data_t):
+        self._test_select_tosa_BI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_u55_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u55_BI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_u55_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u55_BI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_u85_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u85_BI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_u85_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u85_BI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index ac4417c79a..e860a2bfcc 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -188,7 +188,7 @@ def quantized_relu_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
-    return X.new_empty(X.size(), dtype=torch.uint8)
+    return X.new_empty(X.size(), dtype=X.dtype)


 @register_fake("cadence::quantized_matmul")
diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
index f025a1cc6f..18381a26e0 100644
--- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
@@ -45,7 +45,10 @@ void dequantize_per_tensor_out(
     const int32_t* input_data = input.const_data_ptr<int32_t>();
     dequantize(out_data, input_data, scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }
diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp
index 9cc84fffa3..c65d62968f 100644
--- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp
@@ -49,7 +49,10 @@ void quantize_per_tensor_out(
     cadence::impl::HiFi::kernels::quantize(
         out_data, input_data, 1. / scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled output dtype %hhd",
+        static_cast<int8_t>(out.scalar_type()));
   }
 }
diff --git a/backends/cadence/reference/operators/quantized_conv_out.cpp b/backends/cadence/reference/operators/quantized_conv_out.cpp
index b37c5884c1..de19f3ef43 100644
--- a/backends/cadence/reference/operators/quantized_conv_out.cpp
+++ b/backends/cadence/reference/operators/quantized_conv_out.cpp
@@ -248,6 +248,11 @@ void quantized_conv_out(
         output_scale,
         (int8_t)output_zero_point,
         per_tensor_quantized);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }
diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp
index a02794c179..7bb1bf6fb4 100644
--- a/backends/cadence/reference/operators/quantized_linear_out.cpp
+++ b/backends/cadence/reference/operators/quantized_linear_out.cpp
@@ -17,8 +17,8 @@ using executorch::aten::Tensor;
 using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;

-void quantized_linear_out(
-    KernelRuntimeContext& ctx,
+template <typename T>
+void inline _typed_quantized_linear(
     const Tensor& src,
     const Tensor& weight,
     const Tensor& bias,
@@ -27,14 +27,11 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
-  // Assuming uint8_t for now, but needs to be updated for other quantization
-  // types
-  const uint8_t* __restrict__ src_data = src.const_data_ptr<uint8_t>();
-  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const T* __restrict__ src_data = src.const_data_ptr<T>();
+  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  T* __restrict__ out_data = out.mutable_data_ptr<T>();

   int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
@@ -71,11 +68,53 @@ void quantized_linear_out(
             (weight_data[j * N + k] - weight_zero_point);
       }
       out_data[i * M + j] =
-          kernels::quantize<uint8_t>(sum, out_scale, out_zero_point);
+          kernels::quantize<T>(sum, out_scale, out_zero_point);
     }
   }
 }

+void quantized_linear_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t src_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+    _typed_quantized_linear<uint8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
+    _typed_quantized_linear<int8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(src.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl
diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp
index bf901105ea..d12fc533e7 100644
--- a/backends/cadence/reference/operators/quantized_matmul_out.cpp
+++ b/backends/cadence/reference/operators/quantized_matmul_out.cpp
@@ -144,6 +144,11 @@ void quantized_matmul_out(
         out_zero_point,
         transposed,
         out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(X.scalar_type()));
   }
 }
diff --git a/devtools/bundled_program/schema/scalar_type.fbs b/devtools/bundled_program/schema/scalar_type.fbs
index a8da080c67..fc299ac691 100644
--- a/devtools/bundled_program/schema/scalar_type.fbs
+++ b/devtools/bundled_program/schema/scalar_type.fbs
@@ -24,9 +24,14 @@ enum ScalarType : byte {
   QINT32 = 14,
   QUINT4X2 = 16,
   QUINT2X4 = 17,
+  BITS16 = 22,
   // Types currently not implemented.
   // COMPLEXHALF = 8,
   // COMPLEXFLOAT = 9,
   // COMPLEXDOUBLE = 10,
   // BFLOAT16 = 15,
+  // BITS1x8 = 18,
+  // BITS2x4 = 19,
+  // BITS4x2 = 20,
+  // BITS8 = 21,
 }
diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp
index 4c05bb5ace..cfd1d2ae14 100644
--- a/devtools/etdump/etdump_flatcc.cpp
+++ b/devtools/etdump/etdump_flatcc.cpp
@@ -55,6 +55,8 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type(
       return executorch_flatbuffer_ScalarType_DOUBLE;
     case exec_aten::ScalarType::Bool:
       return executorch_flatbuffer_ScalarType_BOOL;
+    case exec_aten::ScalarType::Bits16:
+      return executorch_flatbuffer_ScalarType_BITS16;
     default:
       ET_CHECK_MSG(
           0,
diff --git a/devtools/etdump/scalar_type.fbs b/devtools/etdump/scalar_type.fbs
index a8da080c67..fc299ac691 100644
--- a/devtools/etdump/scalar_type.fbs
+++ b/devtools/etdump/scalar_type.fbs
@@ -24,9 +24,14 @@ enum ScalarType : byte {
   QINT32 = 14,
   QUINT4X2 = 16,
   QUINT2X4 = 17,
+  BITS16 = 22,
   // Types currently not implemented.
   // COMPLEXHALF = 8,
   // COMPLEXFLOAT = 9,
   // COMPLEXDOUBLE = 10,
   // BFLOAT16 = 15,
+  // BITS1x8 = 18,
+  // BITS2x4 = 19,
+  // BITS4x2 = 20,
+  // BITS8 = 21,
 }
diff --git a/examples/llm_pte_finetuning/runner.py b/examples/llm_pte_finetuning/runner.py
index 1800ae7cc3..0deebcf010 100644
--- a/examples/llm_pte_finetuning/runner.py
+++ b/examples/llm_pte_finetuning/runner.py
@@ -98,7 +98,7 @@ def main() -> None:
     # for us to update with the gradients in-place.
     # See https://github.com/pytorch/executorch/blob/main/extension/pybindings/pybindings.cpp#L736
     # for more info.
-    out = et_mod.forward((tokens, labels), clone_outputs=False)  # pyre-ignore
+    out = et_mod.forward((tokens, labels), clone_outputs=False)
     loss = out[0]
     losses.append(loss.item())
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 9bd16fa7c0..d328adffbf 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -126,10 +126,6 @@ runtime.python_library(
 runtime.python_binary(
     name = "eval_llama",
     main_function = "executorch.examples.models.llama.eval_llama.main",
-    preload_deps = [
-        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
-        "//executorch/kernels/quantized:aot_lib",
-    ],
     deps = [
         ":eval_library",
         "//caffe2:torch",
diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index 285d2f874d..f0ef5d6758 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -293,6 +293,7 @@ def eval_llama(
     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets

@@ -302,7 +303,7 @@ def eval_llama(
     with torch.no_grad():
         eval_results = simple_evaluate(
             model=eval_wrapper,
-            tasks=args.tasks,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
+            tasks=args.tasks,
             num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
             limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
         )
diff --git a/examples/models/llama/evaluate/eager_eval.py b/examples/models/llama/evaluate/eager_eval.py
index d7f4dae78f..b3f04ef3bb 100644
--- a/examples/models/llama/evaluate/eager_eval.py
+++ b/examples/models/llama/evaluate/eager_eval.py
@@ -47,6 +47,10 @@ def eot_token_id(self):
             return self._tokenizer.eot_id
         return self._tokenizer.eos_id

+    @property
+    def prefix_token_id(self):
+        return self.eot_token_id
+
     @property
     def max_length(self):
         return self._max_seq_length
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 5016ba1fcb..f3822b6866 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -665,14 +665,16 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         from functools import partial

+        # pyre-ignore
         from executorch.backends.qualcomm.quantizer.custom_annotation import (
             get_custom_quant_ios_dtype,
         )

+        # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
             partial(
-                get_custom_quant_ios_dtype,
+                get_custom_quant_ios_dtype,  # pyre-ignore
                 builder_exported_to_edge.model.layers[
                     0
                 ].attention.kv_cache.past_k_caches.shape,
diff --git a/examples/models/llama/runner/TARGETS b/examples/models/llama/runner/TARGETS
index 2341af9282..34cdd62be7 100644
--- a/examples/models/llama/runner/TARGETS
+++ b/examples/models/llama/runner/TARGETS
@@ -1,8 +1,37 @@
 # Any targets that should be shared between fbcode and xplat must be defined in
 # targets.bzl. This file can contain fbcode-only targets.

+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")

 oncall("executorch")

 define_common_targets()
+
+runtime.python_library(
+    name = "eager_runner_library",
+    srcs = [
+        "eager.py",
+        "generation.py"
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama.runner",
+    visibility = [
+        "//bento/...",
+        "//bento_kernels/...",
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/examples/models/llama:export_library",
+    ],
+)
+
+runtime.python_binary(
+    name = "eager",
+    main_function = "executorch.examples.models.llama.runner.eager.main",
+    deps = [
+        ":eager_runner_library",
+        "//caffe2:torch",
+    ],
+)
diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index e116e08a09..b8792151a0 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -9,14 +9,13 @@
 from typing import Optional

 import torch
-
-from examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
 )
+from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.runner.generation import LlamaRunner
-from executorch.extension.llm.export import LLMEdgeManager
+from executorch.extension.llm.export.builder import LLMEdgeManager


 class EagerLlamaRunner(LlamaRunner):
@@ -43,8 +42,8 @@ def __init__(self, args):
     def forward(
         self,
-        tokens: Optional[torch.LongTensor] = None,
-        input_pos: Optional[torch.LongTensor] = None,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model.forward(tokens=tokens, input_pos=input_pos)
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index e332e0ebe2..867c41aabe 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -15,7 +15,7 @@
 class CompletionPrediction(TypedDict, total=False):
     generation: str
-    tokens: List[str]  # not required
+    tokens: List[int]  # not required


 def sample_top_p(probs, p):
@@ -47,6 +47,7 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
     if temperature > 0:
         probs = torch.softmax(logits / temperature, dim=-1)
         return sample_top_p(probs, top_p).item()
+    # Pyre-ignore[7]: Incompatible return type [7]: Expected `int` but got `Union[bool, float, int]`
     return torch.argmax(logits, dim=-1).item()

@@ -60,8 +61,8 @@ def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cp
     @abstractmethod
     def forward(
         self,
-        tokens: Optional[torch.LongTensor] = None,
-        input_pos: Optional[torch.LongTensor] = None,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         pass
diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py
index 90e7fc46dd..73005d9333 100644
--- a/examples/models/llama/runner/native.py
+++ b/examples/models/llama/runner/native.py
@@ -42,8 +42,8 @@ def __init__(self, args):
     def forward(
         self,
-        tokens: Optional[torch.LongTensor] = None,
-        input_pos: Optional[torch.LongTensor] = None,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return (
             self.model.forward((tokens, input_pos))
diff --git a/exir/scalar_type.py b/exir/scalar_type.py
index b789a09f3a..5d41038610 100644
--- a/exir/scalar_type.py
+++ b/exir/scalar_type.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+
 from enum import IntEnum

@@ -26,4 +28,4 @@ class ScalarType(IntEnum):
     BFLOAT16 = 15
     QUINT4x2 = 16
     QUINT2x4 = 17
-    Bits16 = 22
+    BITS16 = 22
diff --git a/exir/tensor.py b/exir/tensor.py
index d63ed5d262..a40bef4e5e 100644
--- a/exir/tensor.py
+++ b/exir/tensor.py
@@ -262,7 +262,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
     torch.qint32: ScalarType.QINT32,
     torch.bfloat16: ScalarType.BFLOAT16,
     torch.quint4x2: ScalarType.QUINT4x2,
-    torch.uint16: ScalarType.Bits16,
+    torch.uint16: ScalarType.BITS16,
 }
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 1049b9da30..96097311d8 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -42,13 +42,13 @@ bool utf8_check_validity(const char* str, size_t length) {
     uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
     if ((byte & 0xE0) == 0xC0 && (next_byte & 0xC0) == 0x80) {
       // 2-byte sequence
-      i += 2;
+      i += 1;
     } else if (
         (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
         (i + 2 < length) && (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80) {
       // 3-byte sequence
-      i += 3;
+      i += 2;
     } else if (
         (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
         (i + 2 < length) &&
         (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
         (i + 3 < length) && (static_cast<uint8_t>(str[i + 3]) & 0xC0) == 0x80) {
       // 4-byte sequence
-      i += 4;
+      i += 3;
     } else {
       return false; // Invalid sequence
     }
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 5470670a4c..e28a8c73cc 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -219,7 +219,7 @@ def pt2e_calibrate(
             from executorch.examples.models.llama.eval_llama_lib import (
                 GraphModuleEvalWrapper,
             )
-            from lm_eval.evaluator import simple_evaluate  # pyre-ignore[21]
+            from lm_eval.evaluator import simple_evaluate
         except ImportError:
             raise ImportError(
                 "Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi
index 818df1f760..fc44ce388a 100644
--- a/extension/pybindings/pybindings.pyi
+++ b/extension/pybindings/pybindings.pyi
@@ -33,11 +33,20 @@ class ExecuTorchModule:
     """

     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def __call__(self, inputs: Any) -> List[Any]: ...
+    def __call__(self, inputs: Any, clone_outputs: bool = True) -> List[Any]: ...
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
+    def run_method(
+        self,
+        method_name: str,
+        inputs: Sequence[Any],  # pyre-ignore[2]: "Any" in parameter type annotations.
+        clone_outputs: bool = True,
+    ) -> List[Any]: ...
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
+    def forward(
+        self,
+        inputs: Sequence[Any],  # pyre-ignore[2]: "Any" in parameter type annotations.
+        clone_outputs: bool = True,
+    ) -> List[Any]: ...
     # pyre-ignore[3]: "Any" in return type annotations.
     def plan_execute(self) -> List[Any]: ...
     # Bundled program methods.
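+    # Illustrative note (not part of the upstream stub): with the default
+    # clone_outputs=True the returned tensors are copies; clone_outputs=False
+    # returns tensors that alias the runtime's static memory, so in-place
+    # updates (e.g. an optimizer step) mutate the real weights. See
+    # examples/llm_pte_finetuning/runner.py above.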
diff --git a/extension/training/TARGETS b/extension/training/TARGETS
new file mode 100644
index 0000000000..d7a1a16137
--- /dev/null
+++ b/extension/training/TARGETS
@@ -0,0 +1,20 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
+
+python_library(
+    name = "lib",
+    srcs = [
+        "__init__.py",
+    ],
+    deps = [
+        "//executorch/extension/training/pybindings:_training_lib",
+        "//executorch/extension/training/pybindings:_training_module",
+    ],
+)
diff --git a/extension/training/__init__.py b/extension/training/__init__.py
index e69de29bb2..f5e0254bf8 100644
--- a/extension/training/__init__.py
+++ b/extension/training/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from executorch.extension.training.pybindings._training_lib import get_sgd_optimizer
+
+from executorch.extension.training.pybindings._training_module import (
+    _load_for_executorch_for_training,
+    _load_for_executorch_for_training_from_buffer,
+    TrainingModule,
+)
+
+__all__ = [
+    "get_sgd_optimizer",
+    "TrainingModule",
+    "_load_for_executorch_for_training_from_buffer",
+    "_load_for_executorch_for_training",
+]
diff --git a/extension/training/pybindings/TARGETS b/extension/training/pybindings/TARGETS
new file mode 100644
index 0000000000..6aa11ea672
--- /dev/null
+++ b/extension/training/pybindings/TARGETS
@@ -0,0 +1,40 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
+
+runtime.cxx_python_extension(
+    name = "_training_lib",
+    srcs = [
+        "_training_lib.cpp",
+    ],
+    base_module = "executorch.extension.training.pybindings",
+    types = ["_training_lib.pyi"],
+    visibility = ["//executorch/extension/training/..."],
+    deps = [
+        "//executorch/extension/aten_util:aten_bridge",
+        "//executorch/extension/training/optimizer:sgd",
+    ],
+    external_deps = [
+        "pybind11",
+        "libtorch_python",
+    ],
+)
+
+runtime.python_library(
+    name = "_training_module",
+    srcs = [
+        "_training_module.py",
+    ],
+    base_module = "executorch.extension.training.pybindings",
+    visibility = ["//executorch/extension/training/..."],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
diff --git a/extension/training/pybindings/_training_lib.cpp b/extension/training/pybindings/_training_lib.cpp
new file mode 100644
index 0000000000..59cd11be4a
--- /dev/null
+++ b/extension/training/pybindings/_training_lib.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <memory>
+
+#include <ATen/Tensor.h>
+#include <torch/csrc/utils/pybind.h>
+#include <torch/python.h>
+#include "executorch/extension/tensor/tensor.h"
+#include "executorch/extension/training/optimizer/sgd.h"
+#ifndef USE_ATEN_LIB
+#include <executorch/extension/aten_util/aten_bridge.h>
+#endif
+
+namespace py = pybind11;
+
+namespace executorch {
+namespace extension {
+namespace training {
+
+namespace {
+
+struct PySGD final {
+  explicit PySGD(
+      const py::dict& named_params,
+      double lr,
+      double momentum,
+      double dampening,
+      double weight_decay,
+      bool nesterov)
+      : sgd_(nullptr),
+        fqns_()
+#ifndef USE_ATEN_LIB
+        ,
+        params_()
+#endif
+  {
+    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
+    auto py_named_params =
+        py::cast<std::map<std::string, at::Tensor>>(named_params);
+    const auto params_size = py::len(named_params);
+    fqns_ = std::vector<std::string>();
+    fqns_.reserve(params_size);
+
+    for (auto pair : py_named_params) {
+      fqns_.push_back(pair.first);
+      exec_aten::string_view v{fqns_.back().c_str(), pair.first.size()};
+#ifndef USE_ATEN_LIB
+      // convert at::Tensor to torch::executor::Tensor
+      params_.emplace_back(alias_tensor_ptr_to_attensor(pair.second));
+      cpp_inputs.insert({v, *params_.back()});
+#else
+      cpp_inputs.insert({v, pair.second});
+#endif
+    }
+    sgd_ = std::make_unique<extension::training::optimizer::SGD>(
+        cpp_inputs,
+        extension::training::optimizer::SGDOptions(
+            lr, momentum, dampening, weight_decay, nesterov));
+  }
+
+  // Not needed for now, so just delete.
+  PySGD(const PySGD&) = delete;
+  PySGD& operator=(const PySGD&) = delete;
+  PySGD(PySGD&&) = delete;
+  PySGD& operator=(PySGD&&) = delete;
+
+  void step(const py::dict& py_dict) {
+    auto py_named_gradients =
+        py::cast<std::map<std::string, at::Tensor>>(py_dict);
+    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
+
+    std::vector<std::string> fqn;
+#ifndef USE_ATEN_LIB
+    std::vector<TensorPtr> et_tensors;
+#endif
+
+    // Convert python objects into cpp.
+    for (const auto& pair : py_named_gradients) {
+      fqn.push_back(pair.first);
+      auto at_tensor = pair.second;
+      // alias_etensor_to_attensor will assert on this later, so to better
+      // propagate up to python we check early and throw an exception.
+      if (!at_tensor.is_contiguous()) {
+        auto error_msg = "Gradient is not contiguous.";
+        throw std::runtime_error(error_msg);
+      }
+#ifndef USE_ATEN_LIB
+      // convert at::Tensor to torch::executor::Tensor
+      auto temp = alias_tensor_ptr_to_attensor(at_tensor);
+      et_tensors.push_back(temp);
+      cpp_inputs.insert({pair.first.c_str(), *et_tensors.back()});
+#else
+      cpp_inputs.insert({pair.first.c_str(), at_tensor});
+#endif
+    }
+
+    auto err = sgd_->step(cpp_inputs);
+    if (err != runtime::Error::Ok) {
+      throw std::runtime_error("SGD step failed");
+    }
+  }
+
+ private:
+  // TODO(jakeszwe): Write an optimizer interface and use it here instead of SGD
+  // specifically.
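+  // sgd_ stores string_views that point into the std::strings owned by
+  // fqns_, and (in portable mode) tensors aliased by params_, so both
+  // members must outlive sgd_.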
+  std::unique_ptr<extension::training::optimizer::SGD> sgd_ = nullptr;
+  std::vector<std::string> fqns_;
+
+#ifndef USE_ATEN_LIB // Portable mode
+  std::vector<TensorPtr> params_;
+#endif
+};
+
+static std::unique_ptr<PySGD> get_sgd_optimizer(
+    const py::dict& named_params,
+    double lr,
+    double momentum = 0,
+    double dampening = 0,
+    double weight_decay = 0,
+    bool nesterov = false) {
+  return std::make_unique<PySGD>(
+      named_params, lr, momentum, dampening, weight_decay, nesterov);
+}
+
+} // namespace
+
+PYBIND11_MODULE(_training_lib, m) {
+  m.def(
+      "get_sgd_optimizer",
+      &get_sgd_optimizer,
+      py::arg("named_params"),
+      py::arg("lr") = 0.1,
+      py::arg("momentum") = 0.0,
+      py::arg("dampening") = 0.0,
+      py::arg("weight_decay") = 0.0,
+      py::arg("nesterov") = false);
+  py::class_<PySGD>(m, "ExecuTorchSGD").def("step", &PySGD::step);
+}
+
+} // namespace training
+} // namespace extension
+} // namespace executorch
diff --git a/extension/training/pybindings/_training_lib.pyi b/extension/training/pybindings/_training_lib.pyi
new file mode 100644
index 0000000000..826ee8b164
--- /dev/null
+++ b/extension/training/pybindings/_training_lib.pyi
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+from executorch.exir._warnings import experimental
+from torch import Tensor
+
+@experimental("This API is experimental and subject to change without notice.")
+class ExecuTorchSGD:
+    """SGD Optimizer.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    """
+
+    def step(self, named_gradients: Dict[str, Tensor]) -> None:
+        """Take a step in the direction of the gradients."""
+        ...
+
+@experimental("This API is experimental and subject to change without notice.")
+def get_sgd_optimizer(
+    named_parameters: Dict[str, Tensor],
+    lr: float,
+    momentum: float = 0,
+    dampening: float = 0,
+    weight_decay: float = 0,
+    nesterov: bool = False,
+) -> ExecuTorchSGD:
+    """Creates an SGD optimizer that operates on the passed in named_parameters according to the specified hyper parameters.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    ...
+    """
+    ...
diff --git a/extension/training/pybindings/_training_module.py b/extension/training/pybindings/_training_module.py
new file mode 100644
index 0000000000..27333551c3
--- /dev/null
+++ b/extension/training/pybindings/_training_module.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
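+
+# Typical flow (illustrative; see extension/training/pybindings/test/test.py):
+#     tm = _load_for_executorch_for_training_from_buffer(buf)
+#     user_outs = tm.forward_backward("forward", inputs)
+#     opt = get_sgd_optimizer(tm.named_parameters(), lr=0.1)
+#     opt.step(tm.named_gradients())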
+
+# pyre-unsafe
+
+from typing import Any, Dict, List, Sequence
+
+from executorch.exir._warnings import experimental
+
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch,
+    _load_for_executorch_from_buffer,
+    ExecuTorchModule,
+)
+from torch import Tensor
+
+
+@experimental("This API is experimental and subject to change without notice.")
+class TrainingModule:
+    def __init__(self, module: ExecuTorchModule):
+        self.model = module
+
+        self.gradients_method_prefix = "__et_training_gradients_index_"
+        self.parameters_method_prefix = "__et_training_parameters_index_"
+        self.fqn_method_prefix = "__et_training_fqn_"
+
+        self.named_grads = None
+        self.named_params = None
+
+    def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]:
+        # The default ET model returns a large list of outputs that can logically be
+        # separated into [user outputs, gradients, parameters]. Can use these metadata
+        # methods to slice the list into the correct parts.
+        grad_start_idx = self.model.run_method(
+            self.gradients_method_prefix + method_name, ()
+        )[0]
+        params_start_idx = self.model.run_method(
+            self.parameters_method_prefix + method_name, ()
+        )[0]
+
+        full_outputs = self.model.run_method(method_name, inputs)
+
+        user_outs = full_outputs[:grad_start_idx]
+        grads = full_outputs[grad_start_idx:params_start_idx]
+        params = full_outputs[params_start_idx:]
+
+        # Important that the outputs are not cloned because we need the optimizer to
+        # be able to mutate the actual weights and not clones of them.
+        fqn = self.model.run_method(
+            self.fqn_method_prefix + method_name, (), clone_outputs=False
+        )
+
+        self.named_grads = dict(zip(fqn, grads))
+        if self.named_params is None:
+            self.named_params = dict(zip(fqn, params))
+
+        return user_outs
+
+    def named_gradients(self) -> Dict[str, Tensor]:
+        if self.named_grads is None:
+            raise RuntimeError("Must call forward_backward before named_grads")
+        return self.named_grads
+
+    def named_parameters(self) -> Dict[str, Tensor]:
+        if self.named_grads is None:
+            raise RuntimeError(
+                "Must call forward_backward before named_params. This will be fixed in a later version"
+            )
+        return self.named_params
+
+
+@experimental("This API is experimental and subject to change without notice.")
+def _load_for_executorch_for_training(path: str) -> TrainingModule:
+    et_module = _load_for_executorch(path)
+    return TrainingModule(et_module)
+
+
+@experimental("This API is experimental and subject to change without notice.")
+def _load_for_executorch_for_training_from_buffer(
+    buffer: bytes,
+) -> TrainingModule:
+    et_module = _load_for_executorch_from_buffer(buffer)
+    return TrainingModule(et_module)
diff --git a/extension/training/pybindings/targets.bzl b/extension/training/pybindings/targets.bzl
new file mode 100644
index 0000000000..d27fd6f48f
--- /dev/null
+++ b/extension/training/pybindings/targets.bzl
@@ -0,0 +1,8 @@
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+ """ + + pass diff --git a/extension/training/pybindings/test/TARGETS b/extension/training/pybindings/test/TARGETS new file mode 100644 index 0000000000..34a324dc50 --- /dev/null +++ b/extension/training/pybindings/test/TARGETS @@ -0,0 +1,14 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_test( + name = "test", + srcs = ["test.py"], + visibility = ["//executorch/extension/training/pybindings/test/..."], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/extension/training:lib", + ], +) diff --git a/extension/training/pybindings/test/test.py b/extension/training/pybindings/test/test.py new file mode 100644 index 0000000000..b8feb8558c --- /dev/null +++ b/extension/training/pybindings/test/test.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.exir import to_edge + +from executorch.extension.training import ( + _load_for_executorch_for_training_from_buffer, + get_sgd_optimizer, +) +from torch.export.experimental import _export_forward_backward + + +class TestTraining(unittest.TestCase): + class ModuleSimpleTrain(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + return self.loss(self.linear(x).softmax(dim=0), y) + + def get_random_inputs(self): + return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0])) + + def test(self): + m = self.ModuleSimpleTrain() + ep = torch.export.export(m, m.get_random_inputs()) + ep = _export_forward_backward(ep) + ep = to_edge(ep) + ep = ep.to_executorch() + buffer = ep.buffer + tm = _load_for_executorch_for_training_from_buffer(buffer) + + tm.forward_backward("forward", m.get_random_inputs()) + orig_param = list(tm.named_parameters().values())[0].clone() + optimizer = get_sgd_optimizer( + tm.named_parameters(), + 0.1, + 0, + 0, + 0, + False, + ) + optimizer.step(tm.named_gradients()) + self.assertFalse( + torch.allclose(orig_param, list(tm.named_parameters().values())[0]) + ) diff --git a/extension/training/targets.bzl b/extension/training/targets.bzl new file mode 100644 index 0000000000..d27fd6f48f --- /dev/null +++ b/extension/training/targets.bzl @@ -0,0 +1,8 @@ +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + pass diff --git a/schema/scalar_type.fbs b/schema/scalar_type.fbs index a8da080c67..fc299ac691 100644 --- a/schema/scalar_type.fbs +++ b/schema/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, }