diff --git a/lib/api/onnx-impl.ts b/lib/api/onnx-impl.ts index 1f20accf..52f4896d 100644 --- a/lib/api/onnx-impl.ts +++ b/lib/api/onnx-impl.ts @@ -8,6 +8,7 @@ import {WebGLBackend} from '../backends/backend-webgl'; import {Environment} from './env'; import {envImpl} from './env-impl'; import {Backend} from './onnx'; +import {MixedBackend} from '../backends/backend-mixed'; export * from './env'; export * from './onnx'; @@ -17,7 +18,8 @@ export * from './inference-session'; export const backend: Backend = { cpu: new CpuBackend(), wasm: new WasmBackend(), - webgl: new WebGLBackend() + webgl: new WebGLBackend(), + mixed: new MixedBackend() }; export const ENV: Environment = envImpl; diff --git a/lib/backend.ts b/lib/backend.ts index 5df89984..ed135a6b 100644 --- a/lib/backend.ts +++ b/lib/backend.ts @@ -82,7 +82,7 @@ const backendsCache: Map = new Map(); */ export async function Backend(hint?: string|ReadonlyArray): Promise { if (!hint) { - return Backend(['webgl', 'wasm', 'cpu']); + return Backend(['webgl', 'mixed', 'wasm', 'cpu']); } else { const hints = typeof hint === 'string' ? [hint] : hint; diff --git a/lib/backends/backend-mixed.ts b/lib/backends/backend-mixed.ts new file mode 100644 index 00000000..f4f91e46 --- /dev/null +++ b/lib/backends/backend-mixed.ts @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {SessionHandler} from '../backend'; +import {Session} from '../session'; +import {MixedSessionHandler} from './mixed-session-handler'; +import {WebGLBackend} from './backend-webgl'; + + +export class MixedBackend extends WebGLBackend { + createSessionHandler(context: Session.Context): SessionHandler { + return new MixedSessionHandler(this, context); + } +} diff --git a/lib/backends/mixed-session-handler.ts b/lib/backends/mixed-session-handler.ts new file mode 100644 index 00000000..fcff77df --- /dev/null +++ b/lib/backends/mixed-session-handler.ts @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {WebGLSessionHandler} from './webgl/session-handler'; +import {Graph} from '../graph'; +import {OpSet, resolveOperator} from '../opset'; +import {Operator} from '../operators'; +import {CPU_OP_RESOLVE_RULES} from './cpu/op-resolve-rules'; +import {Logger} from '../instrument'; + +export class MixedSessionHandler extends WebGLSessionHandler { + resolve(node: Graph.Node, opsets: ReadonlyArray): Operator { + try { + return super.resolve(node, opsets); + } catch (e) { + Logger.warning( + 'MixedSessionHandler', + `Unable to initialize operator '${node.opType}' with webgl. trying with cpu...`); + const op = resolveOperator(node, opsets, CPU_OP_RESOLVE_RULES); + op.initialize(node.attributes); + return op; + } + } +} \ No newline at end of file