Skip to content

Commit

Permalink
Add a simple multi-threaded test
Browse files Browse the repository at this point in the history
Differential Revision: D65162642

Pull Request resolved: #7143
  • Loading branch information
digantdesai authored Dec 3, 2024
1 parent b4eda5f commit 5a9e7a4
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 19 deletions.
Empty file added backends/test/README.md
Empty file.
8 changes: 8 additions & 0 deletions backends/test/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

# NOTE(review): presumably registers "executorch" as the owning oncall for
# this package - confirm against the build-system documentation.
oncall("executorch")

# Instantiate the shared targets loaded from targets.bzl with the fbcode flag
# turned on.
define_common_targets(is_fbcode = True)
164 changes: 164 additions & 0 deletions backends/test/multi_method_delegate_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#include <gtest/gtest.h>

#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/runtime.h>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
#include <executorch/extension/runner_util/inputs.h>

using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryManager;
using executorch::runtime::Method;
using executorch::runtime::MethodMeta;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::Span;

using executorch::extension::FileDataLoader;
using executorch::extension::MallocMemoryAllocator;
using executorch::extension::prepare_input_tensors;

/*
 * Backend agnostic base class.
 *
 * SetUp() initializes the ExecuTorch runtime; run() performs one complete
 * load -> execute -> read-outputs pass over a single PTE file and is meant to
 * be invoked from several threads at once by derived fixtures.
 */
class ETPTEMethodRunBaseTest : public ::testing::Test {
 protected:
  void SetUp() override {
    executorch::runtime::runtime_init();
  }

  // Runs the PTE e2e without using outside resources.
  // This will run in a single thread.
  // TODO(T208989128) - Add Synchronizer based run method.
  //
  // id           - caller supplied index; the method is executed once and
  //                then `id % 7` additional times.
  // kTestPTEPath - path of the PTE file to load.
  // kMethodName  - name of the method inside the program to load and run.
  // count        - incremented exactly once, and only when every step above
  //                succeeded. Any failed assertion returns early and leaves
  //                it untouched.
  void run(
      const int id,
      const std::string& kTestPTEPath,
      const std::string& kMethodName,
      std::atomic<size_t>& count) const {
    Result<FileDataLoader> loader = FileDataLoader::from(kTestPTEPath.c_str());
    ASSERT_EQ(loader.error(), Error::Ok);

    Result<Program> program = Program::load(
        &loader.get(), Program::Verification::InternalConsistency);
    ASSERT_EQ(program.error(), Error::Ok);

    Result<MethodMeta> method_meta = program->method_meta(kMethodName.c_str());
    ASSERT_EQ(method_meta.error(), Error::Ok);

    const size_t num_memory_planned_buffers =
        method_meta->num_memory_planned_buffers();

    // planned_buffers owns the storage; planned_spans holds the non-owning
    // views over that storage which are handed to the HierarchicalAllocator.
    std::vector<std::unique_ptr<uint8_t[]>> planned_buffers;
    std::vector<Span<uint8_t>> planned_spans;
    planned_buffers.reserve(num_memory_planned_buffers);
    planned_spans.reserve(num_memory_planned_buffers);
    for (size_t i = 0; i < num_memory_planned_buffers; ++i) {
      const size_t buffer_size =
          static_cast<size_t>(method_meta->memory_planned_buffer_size(i).get());
      planned_buffers.push_back(std::make_unique<uint8_t[]>(buffer_size));
      planned_spans.push_back({planned_buffers.back().get(), buffer_size});
    }

    auto method_allocator = std::make_unique<MallocMemoryAllocator>();
    auto memory_planned_allocator = std::make_unique<HierarchicalAllocator>(
        Span(planned_spans.data(), planned_spans.size()));
    auto temp_allocator = std::make_unique<MallocMemoryAllocator>();

    auto memory_manager = std::make_unique<MemoryManager>(
        method_allocator.get(),
        memory_planned_allocator.get(),
        temp_allocator.get());

    Result<Method> method =
        program->load_method(kMethodName.c_str(), memory_manager.get());
    ASSERT_EQ(method.error(), Error::Ok);

    auto inputs = prepare_input_tensors(*method);
    ASSERT_EQ(inputs.error(), Error::Ok);

    // The first execution must be checked on its own: when id % 7 == 0 the
    // loop below never runs, and `err` is overwritten further down, so a
    // failure here would otherwise go unnoticed.
    Error err = method->execute();
    ASSERT_EQ(err, Error::Ok);
    for (int i = 0; i < id % 7; i++) {
      err = method->execute();
      ASSERT_EQ(err, Error::Ok);
    }

    std::vector<EValue> outputs(method->outputs_size());
    err = method->get_outputs(outputs.data(), outputs.size());
    // Reported through gtest like every other step, rather than aborting the
    // whole process.
    ASSERT_EQ(err, Error::Ok);
    // TODO(T208989129) - Add validation of outputs using bundled
    // inputs/outputs.
    count++;
  }
};

// Fixture for the XNNPACK multi-threaded delegate test.
//
// SetUp() reads the locations of the two PTE files from the environment
// variables ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH and
// ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH, and fails the test when either one
// is not set.
class XNNPACKMultiDelegateTest : public ETPTEMethodRunBaseTest {
 protected:
  std::string kTestPTE1Path, kTestPTE2Path;
  std::string kMethodName;
  // Number of threads the test spawns; assigned in SetUp().
  int num_threads = 0;

  void SetUp() override {
    ETPTEMethodRunBaseTest::SetUp();
    const char* pte1_path =
        std::getenv("ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH");
    if (pte1_path == nullptr) {
      std::cerr << "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH is not set"
                << std::endl;
      FAIL();
    }
    kTestPTE1Path = std::string(pte1_path);

    const char* pte2_path =
        std::getenv("ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH");
    // Check the second path itself. Constructing a std::string from a null
    // pointer is undefined behaviour, so this guard must not be skipped.
    if (pte2_path == nullptr) {
      std::cerr << "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH is not set"
                << std::endl;
      FAIL();
    }
    kTestPTE2Path = std::string(pte2_path);

    num_threads = 40;
    kMethodName = "forward";
  }
};

// This test is to validate the assumption that the delegate is thread safe.
// That includes the following:
// 1. The delegate can be initialized by multiple threads in parallel.
// 2. The delegate can be executed by multiple threads in parallel.
// 3. The delegate can be destroyed by multiple threads in parallel.
// Regardless of the underlying implementation of the delegate.
// This is particularly important when we have shared resources across
// delegate instances through a singleton backend instance.
// Spawns num_threads threads that each call run() on one of the two PTE
// files, joins them all, and then checks that every thread reported success
// through the shared counter.
TEST_F(XNNPACKMultiDelegateTest, MultipleThreads) {
  ASSERT_FALSE(kTestPTE1Path.empty());
  ASSERT_FALSE(kTestPTE2Path.empty());
  // Must be strictly positive: the value is used as a container size below,
  // where a negative int would wrap around to a huge size_t.
  ASSERT_GT(num_threads, 0);
  ASSERT_FALSE(kMethodName.empty());

  const size_t expected_runs = static_cast<size_t>(num_threads);

  std::vector<std::thread> threads;
  threads.reserve(expected_runs);
  std::atomic<size_t> count{0};

  for (int i = 0; i < num_threads; i++) {
    // Threads whose index is a multiple of 7 run the second PTE; all other
    // threads run the first one.
    threads.emplace_back([this, &count, i]() {
      run(i, i % 7 ? kTestPTE1Path : kTestPTE2Path, kMethodName, count);
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
  // run() increments the counter only after all of its checks passed.
  ASSERT_EQ(count.load(), expected_runs);
}

// TODO(T208989291): Add more tests here. For example,
// - PTEs with multiple methods
// - PTEs with producer and consumer relationships in different threads
// - PTEs with more than 1 delegate instance
// - PTEs with different types of delegate instances
// - Add more patterns of delegate initialization and execution
29 changes: 29 additions & 0 deletions backends/test/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets(is_fbcode = False):
    """Declares the targets common to the fbcode and xplat builds.

    Both the TARGETS and the BUCK file that live next to this targets.bzl
    file are expected to call this function.
    """

    # The delegate test is only declared for non-OSS fbcode builds.
    if not runtime.is_oss and is_fbcode:
        runtime.cxx_test(
            name = "multi_method_delegate_test",
            srcs = ["multi_method_delegate_test.cpp"],
            deps = [
                "//executorch/runtime/executor:program",
                "//executorch/extension/data_loader:file_data_loader",
                "//executorch/extension/memory_allocator:malloc_memory_allocator",
                "//executorch/kernels/portable:generated_lib",
                "//executorch/backends/xnnpack:xnnpack_backend",
                "//executorch/extension/runner_util:inputs",
            ],
            # Environment variables through which the test receives the
            # locations of the exported PTE files.
            env = {
                "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleAddLarge.pte])",
                "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleSubLarge.pte])",
            },
        )
88 changes: 69 additions & 19 deletions test/models/export_delegated_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import executorch.exir as exir

import torch
from executorch.exir import to_edge
from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.test.backend_with_compiler_demo import (
Expand Down Expand Up @@ -52,6 +52,41 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))


class ModuleAddLarge(nn.Module):
    """Chains three element-wise additions over its three inputs."""

    def forward(
        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
    ) -> torch.Tensor:
        # (a + b), then (a + b) + c, and finally the sum of those two results.
        first_sum = torch.add(a, b)
        second_sum = torch.add(first_sum, c)
        return torch.add(first_sum, second_sum)

    def get_random_inputs(self) -> Sequence[torch.Tensor]:
        # Despite the name, the values are fixed: three (10, 10, 10) tensors
        # filled with 1, 2 and 3 respectively.
        side = 10  # to create a large tensor
        shape = (side, side, side)
        return (torch.ones(shape), 2 * torch.ones(shape), 3 * torch.ones(shape))


class ModuleSubLarge(nn.Module):
    """Chains four element-wise subtractions over its three inputs."""

    def forward(
        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
    ) -> torch.Tensor:
        # (a - b), then (a - b) - c, then the difference of those two results,
        # and finally that difference minus c.
        first_diff = torch.sub(a, b)
        second_diff = torch.sub(first_diff, c)
        third_diff = torch.sub(first_diff, second_diff)
        return torch.sub(third_diff, c)

    def get_random_inputs(self) -> Sequence[torch.Tensor]:
        # Despite the name, the values are fixed: three (10, 10, 10) tensors
        # filled with 1, 2 and 3 respectively.
        side = 10  # to create a large tensor
        shape = (side, side, side)
        return (torch.ones(shape), 2 * torch.ones(shape), 3 * torch.ones(shape))


#
# Backends
#
Expand Down Expand Up @@ -95,30 +130,45 @@ def __init__(self, fn):
def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)

edge: exir.EdgeProgramManager = to_edge(
export(WrapperModule(getattr(eager_module, method)), args=inputs)
exported_program = export(WrapperModule(getattr(eager_module, method)), args=inputs)

edge_config = EdgeCompileConfig(_check_ir_validity=False)
et_config = exir.ExecutorchBackendConfig(
extract_delegate_segments=extract_delegate_segments,
constant_tensor_alignment=constant_tensor_alignemnt,
delegate_alignment=delegate_alignment,
)

lowered_module = to_backend(backend_id, edge.exported_program(), compile_specs=[])
if backend_id == "XnnpackBackend":
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackPartitioner,
)

class CompositeModule(nn.Module):
def __init__(self):
super().__init__()
self.lowered_module = lowered_module
executorch_program = to_edge_transform_and_lower(
exported_program,
compile_config=edge_config,
partitioner=[XnnpackPartitioner()],
).to_executorch(config=et_config)
else:
edge: exir.EdgeProgramManager = to_edge(exported_program)
lowered_module = to_backend(
backend_id, edge.exported_program(), compile_specs=[]
)

def forward(self, *args, **kwargs):
return self.lowered_module(*args, **kwargs)
class CompositeModule(nn.Module):
def __init__(self):
super().__init__()
self.lowered_module = lowered_module

composite_module = CompositeModule()
composite_module(*inputs)
def forward(self, *args, **kwargs):
return self.lowered_module(*args, **kwargs)

executorch_program = to_edge(export(composite_module, args=inputs)).to_executorch(
config=exir.ExecutorchBackendConfig(
extract_delegate_segments=extract_delegate_segments,
constant_tensor_alignment=constant_tensor_alignemnt,
delegate_alignment=delegate_alignment,
)
)
composite_module = CompositeModule()
composite_module(*inputs)

executorch_program = to_edge(
export(composite_module, args=inputs)
).to_executorch(config=et_config)

return executorch_program.buffer

Expand Down
24 changes: 24 additions & 0 deletions test/models/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,17 @@ def define_common_targets():
par_style = "xar",
deps = [
":export_delegated_program_lib",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",

],
visibility = [], # Private
)

# Class names of nn.Modules for :exported_delegated_programs to export.
DELEGATED_MODULES_TO_EXPORT = [
"ModuleAddMul",
"ModuleAddLarge",
"ModuleSubLarge",
]

# Name of the backend to use when exporting delegated programs.
Expand Down Expand Up @@ -153,3 +157,23 @@ def define_common_targets():
"//executorch/test/...",
],
)

runtime.genrule(
name = "exported_xnnp_delegated_programs",
cmd = "$(exe :export_delegated_program)" +
" --modules " + ",".join(DELEGATED_MODULES_TO_EXPORT) +
" --backend_id " + "XnnpackBackend" +
" --outdir $OUT",
outs = {
fname + ".pte": [fname + ".pte"]
for fname in DELEGATED_MODULES_TO_EXPORT
},
default_outs = ["."],
visibility = [
"//executorch/runtime/executor/test/...",
"//executorch/backends/test/...",
"//executorch/test/...",
"@EXECUTORCH_CLIENTS",
],
env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",},
)

0 comments on commit 5a9e7a4

Please sign in to comment.