Skip to content

Commit

Permalink
Sync infranstructure with Dawn
Browse files Browse the repository at this point in the history
  • Loading branch information
fujunwei authored and mingmingtasd committed May 30, 2022
1 parent cedc10b commit 6d05ff5
Show file tree
Hide file tree
Showing 88 changed files with 568 additions and 638 deletions.
4 changes: 2 additions & 2 deletions DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ gclient_gn_args = [

vars = {
'chromium_git': 'https://chromium.googlesource.com',
'dawn_git': 'https://dawn.googlesource.com',
'dawn_git': 'https://github.com/fujunwei',
'github_git': 'https://github.com',

'dawn_standalone': True,
Expand Down Expand Up @@ -45,7 +45,7 @@ deps = {

# Dependencies required for code generator and infrastructure code.
'third_party/dawn': {
'url': '{dawn_git}/dawn.git@bf1c0cf52377b4db2bf3a433dc5056620aad7cdd'
'url': '{dawn_git}/dawn.git@f4c84e239bf8b5b2c4733d68ca38e1e9049fd895'
},

# Dependencies required for backends.
Expand Down
3 changes: 2 additions & 1 deletion build_overrides/webnn.gni
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
webnn_standalone = true

# The paths to WebNN's dependencies
webnn_dawn_root = "//third_party/dawn"
webnn_abseil_dir = "//third_party/abseil-cpp"
dawn_root = "//third_party/dawn"
webnn_googletest_dir = "//third_party/googletest"
webnn_jinja2_dir = "//third_party/jinja2"
webnn_gpgmm_dir = "//third_party/gpgmm"
2 changes: 1 addition & 1 deletion examples/LeNet/LeNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ wnn::Graph LeNet::Build(const std::string& weigthsPath) {
return nullptr;
}

const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(mContext);
const wnn::GraphBuilder builder = utils::CreateGraphBuilder(mContext);

uint32_t byteOffset = 0;
const wnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 28, 28});
Expand Down
6 changes: 3 additions & 3 deletions examples/MobileNetV2/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ int main(int argc, const char* argv[]) {
}
},
&mobilevetv2);
wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);
wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNchw(builder)
: mobilevetv2.LoadNhwc(builder);
wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNCHW(builder)
: mobilevetv2.LoadNHWC(builder);

// Build the graph.
const std::chrono::time_point<std::chrono::high_resolution_clock> compilationStartTime =
Expand Down
2 changes: 1 addition & 1 deletion examples/ResNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) {
}
},
&resnet);
wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);
wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output =
resnet.mLayout == "nchw" ? resnet.LoadNchw(builder) : resnet.LoadNhwc(builder);

Expand Down
43 changes: 13 additions & 30 deletions examples/SampleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static webnn::wire::WireClient* wireClient = nullptr;
static utils::TerribleCommandBuffer* c2sBuf = nullptr;
static utils::TerribleCommandBuffer* s2cBuf = nullptr;

static wnn::Instance clientInstance;
static wnn::Instance instance;
static std::unique_ptr<webnn::native::Instance> nativeInstance;
wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
nativeInstance = std::make_unique<webnn::native::Instance>();
Expand All @@ -67,6 +67,7 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
case CmdBufType::None:
procs = backendProcs;
context = backendContext;
instance = wnn::Instance(nativeInstance->Get());
break;

case CmdBufType::Terrible: {
Expand Down Expand Up @@ -94,23 +95,21 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {

context = contextReservation.context;
#else
webnnProcSetProcs(&procs);
auto instanceReservation = wireClient->ReserveInstance();
wireServer->InjectInstance(nativeInstance->Get(), instanceReservation.id,
instanceReservation.generation);
// Keep the reference instread of using Acquire.
// TODO:: make the instance in the client as singleton object.
clientInstance = wnn::Instance(instanceReservation.instance);
return clientInstance.CreateContext(options);
instance = wnn::Instance(instanceReservation.instance);
#endif
}
default:
dawn::ErrorLog() << "Invaild CmdBufType";
DAWN_ASSERT(0);
}
webnnProcSetProcs(&procs);

return wnn::Context::Acquire(context);
return instance.CreateContext(options);
;
}

void DoFlush() {
Expand All @@ -123,35 +122,15 @@ void DoFlush() {
}

wnn::NamedInputs CreateCppNamedInputs() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateNamedInputs();
#else
return wnn::CreateNamedInputs();
#endif // defined(WEBNN_ENABLE_WIRE)
}

wnn::NamedOperands CreateCppNamedOperands() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateNamedOperands();
#else
return wnn::CreateNamedOperands();
#endif // defined(WEBNN_ENABLE_WIRE)
return instance.CreateNamedInputs();
}

wnn::NamedOutputs CreateCppNamedOutputs() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateNamedOutputs();
#else
return wnn::CreateNamedOutputs();
#endif // defined(WEBNN_ENABLE_WIRE)
return instance.CreateNamedOutputs();
}

wnn::OperatorArray CreateCppOperatorArray() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateOperatorArray();
#else
return wnn::CreateOperatorArray();
#endif // defined(WEBNN_ENABLE_WIRE)
return instance.CreateOperatorArray();
}

bool ExampleBase::ParseAndCheckExampleOptions(int argc, const char* argv[]) {
Expand Down Expand Up @@ -264,6 +243,10 @@ namespace utils {
return activationOperand;
}

wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context) {
return instance.CreateGraphBuilder(context);
}

wnn::Operand BuildInput(const wnn::GraphBuilder& builder,
std::string name,
const std::vector<int32_t>& dimensions,
Expand All @@ -283,7 +266,7 @@ namespace utils {
}

wnn::Graph Build(const wnn::GraphBuilder& builder, const std::vector<NamedOperand>& outputs) {
wnn::NamedOperands namedOperands = CreateCppNamedOperands();
wnn::NamedOperands namedOperands = instance.CreateNamedOperands();
for (auto& output : outputs) {
namedOperands.Set(output.name.c_str(), output.operand);
}
Expand Down
5 changes: 3 additions & 2 deletions examples/SampleUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ namespace utils {
FusedActivation activation,
const void* options = nullptr);

wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context);
wnn::Operand BuildInput(const wnn::GraphBuilder& builder,
std::string name,
const std::vector<int32_t>& dimensions,
Expand Down Expand Up @@ -247,7 +248,7 @@ namespace utils {
void Compute(const wnn::Graph& graph,
const std::vector<NamedInput<T>>& inputs,
const std::vector<NamedOutput<T>>& outputs) {
if (graph.GetHandle() == nullptr) {
if (graph.Get() == nullptr) {
dawn::ErrorLog() << "The graph is invaild.";
}

Expand All @@ -272,7 +273,7 @@ namespace utils {
resource.arrayBufferView.buffer = output.resource.data();
resource.arrayBufferView.byteLength = output.resource.size() * sizeof(float);
mlOutputs.push_back(resource);
namedOutputs.Set(output.name.c_str(), &mlOutputs.back());
namedOutputs.SetOutput(output.name.c_str(), &mlOutputs.back());
}
graph.Compute(namedInputs, namedOutputs);
DoFlush();
Expand Down
2 changes: 1 addition & 1 deletion examples/SqueezeNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) {
}
},
&squeezenet);
wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);
wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output =
squeezenet.mLayout == "nchw" ? squeezenet.LoadNchw(builder) : squeezenet.LoadNhwc(builder);

Expand Down
20 changes: 11 additions & 9 deletions generator/webnn_generator.gni
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import("//third_party/dawn/generator/generator_lib.gni")
import("//third_party/dawn/generator/dawn_generator.gni")
import("../scripts/webnn_overrides_with_defaults.gni")

# Dawn used to put autogenerated files in a lot of different places. When we
Expand All @@ -39,12 +39,16 @@ import("../scripts/webnn_overrides_with_defaults.gni")
# disallowed gen directories.

webnn_allowed_gen_output_dirs = [
"src/dawn/",
"src/dawn/native/",
"src/webnn/",
"src/webnn/native/",
"src/webnn/wire/client/",
"src/webnn/wire/server/",
"src/webnn/wire/",
"include/webnn/",
"emscripten-bits/",
"include/dawn/",
]

# Template to help invoking Dawn code generators based on generator_lib
Expand Down Expand Up @@ -77,7 +81,7 @@ template("webnn_generator") {
forward_variables_from(invoker, "*")

# Set arguments required to find the python libraries for the generator
generator_lib_dir = "${webnn_dawn_root}/generator"
generator_lib_dir = "${dawn_root}/generator"
jinja2_path = webnn_jinja2_dir

# Force Dawn's autogenerated file structure to mirror exactly the source
Expand All @@ -87,35 +91,33 @@ template("webnn_generator") {

# Make sure that we delete stale autogenerated file in directories that are
# no longer used by code generation to avoid include conflicts.
deps = [ "${webnn_root}/generator:remove_stale_autogen_files" ]
deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ]

template_dir = "${webnn_root}/generator/templates"
template_dir = "${dawn_root}/generator/templates"
}
}

# Helper generator for calling the generator from webnn.json
#
# dawn_json_generator("my_target_gen") {
# webnn_json_generator("my_target_gen") {
# # Which generator target to output
# target = "my_target"
#
# # Also supports `outputs` and `custom_gen_dir` like dawn_generator.
# }
template("webnn_json_generator") {
webnn_generator(target_name) {
script = "${webnn_root}/generator/webnn_json_generator.py"
script = "${dawn_root}/generator/dawn_json_generator.py"

# The base arguments for the generator: from this webnn.json, generate this
# target using templates in this directory.
args = [
"--webnn-json",
"--dawn-json",
rebase_path("${webnn_root}/webnn.json", root_build_dir),
"--wire-json",
rebase_path("${webnn_root}/webnn_wire.json", root_build_dir),
"--targets",
invoker.target,
"--dawn-generator-path",
rebase_path("${webnn_root}/third_party/dawn/generator"),
]

forward_variables_from(invoker, "*", [ "target" ])
Expand Down
19 changes: 8 additions & 11 deletions include/webnn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import("../../scripts/webnn_overrides_with_defaults.gni")

import("${webnn_dawn_root}/scripts/dawn_component.gni")
import("${dawn_root}/scripts/dawn_component.gni")
import("${webnn_root}/generator/webnn_generator.gni")

###############################################################################
Expand All @@ -25,8 +25,8 @@ import("${webnn_root}/generator/webnn_generator.gni")
webnn_json_generator("headers_gen") {
target = "headers"
outputs = [
"include/webnn/webnn_proc_table.h",
"include/webnn/webnn.h",
"include/dawn/webnn_proc_table.h",
"include/dawn/webnn.h",
]
}

Expand All @@ -43,7 +43,10 @@ source_set("headers") {

webnn_json_generator("cpp_headers_gen") {
target = "cpp_headers"
outputs = [ "include/webnn/webnn_cpp.h" ]
outputs = [
"include/dawn/webnn_cpp.h",
"include/dawn/webnn_cpp_print.h",
]
}

source_set("cpp_headers") {
Expand All @@ -67,14 +70,8 @@ config("public") {

if (build_with_chromium) {
include_dirs += [
"${webnn_dawn_root}/include",
"${dawn_root}/include",
"${dawn_gen_root}/include",
]
} else {
# TODO: Remove after upgrading webnn infranstructure align with dawn.
include_dirs += [
"${webnn_dawn_root}/src/include",
"${dawn_gen_root}/src/include",
]
}
}
Loading

0 comments on commit 6d05ff5

Please sign in to comment.