Skip to content

Commit

Permalink
fix build in test runner
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Dec 5, 2024
1 parent 68d7e86 commit 19de8f5
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,11 @@ export class TensorResultValidator {
}
}

function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise<ort.Tensor> {
if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) {
throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`);
}
const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
Expand All @@ -612,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
});
}

function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
if (!isGpuBufferSupportedType(type)) {
throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`);
}

const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!;

const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
Expand Down Expand Up @@ -725,7 +725,7 @@ export async function sessionRun(options: {
if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') {
feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]);
} else {
feeds[name] = createGpuTensorForInput(feeds[name]);
feeds[name] = await createGpuTensorForInput(feeds[name]);
}
}
}
Expand All @@ -742,7 +742,7 @@ export async function sessionRun(options: {
if (options.ioBinding === 'ml-tensor') {
fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims);
} else {
fetches[name] = createGpuTensorForOutput(type, dims);
fetches[name] = await createGpuTensorForOutput(type, dims);
}
}
}
Expand Down

0 comments on commit 19de8f5

Please sign in to comment.