Skip to content

Commit

Permalink
Drop null for reshape op
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Nov 29, 2023
1 parent 27039b3 commit 6fd403f
Show file tree
Hide file tree
Showing 11 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion face_recognition/facenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export class FaceNetNhwc {
}

async buildFullyConnected_(input) {
input = this.builder_.reshape(input, [1, null]);
input = this.builder_.reshape(input, [1, 1792]);
const weights = await buildConstantByNpy(this.builder_,
`${this.weightsUrl_}/Bottleneck_kernel_transpose.npy`);
const bias = await buildConstantByNpy(this.builder_,
Expand Down
2 changes: 1 addition & 1 deletion facial_landmark_detection/face_landmark_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export class FaceLandmarkNchw {
if (reshapeSize !== undefined) {
gemm = this.builder_.gemm(this.builder_.reshape(
this.builder_.transpose(await input, {permutation: [0, 2, 3, 1]}),
[null, reshapeSize]), await weights, options);
[1, reshapeSize]), await weights, options);
} else {
gemm = this.builder_.gemm(await input, await weights, options);
}
Expand Down
2 changes: 1 addition & 1 deletion facial_landmark_detection/face_landmark_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export class FaceLandmarkNhwc {
let fc;
if (reshapeSize !== undefined) {
fc = this.builder_.gemm(this.builder_.reshape(
await input, [null, reshapeSize]), await weights, options);
await input, [1, reshapeSize]), await weights, options);
} else {
fc = this.builder_.gemm(await input, await weights, options);
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ export class MobileNetV2Nchw {

const conv3 = await this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(conv3);
const reshape = this.builder_.reshape(pool, [1, null]);
const reshape = this.builder_.reshape(pool, [1, 1280]);
const gemm = await this.buildGemm_(reshape, '104');
return this.builder_.softmax(gemm);
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/mobilenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ export class MobileNetV2Nhwc {
conv3, {windowDimensions: [7, 7], layout: 'nhwc'});
const conv4 = await this.buildConv_(
averagePool2d, '222', 'Logits_Conv2d_1c_1x1_Conv2D', false, {autoPad, filterLayout});
const reshape = this.builder_.reshape(conv4, [1, null]);
const reshape = this.builder_.reshape(conv4, [1, 1001]);
return this.builder_.softmax(reshape);
}

Expand Down
4 changes: 2 additions & 2 deletions image_classification/resnet50v2_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class ResNet50V2Nchw {
const biasName = prefix + '_bias.npy';
const bias = await buildConstantByNpy(this.builder_, biasName);
const options =
{c: this.builder_.reshape(bias, [1, null]), bTranspose: true};
{c: this.builder_.reshape(bias, [1, 1000]), bTranspose: true};
return this.builder_.gemm(input, weight, options);
}

Expand Down Expand Up @@ -148,7 +148,7 @@ export class ResNet50V2Nchw {

const bn3 = await this.buildBatchNorm_(bottleneck16, '2', '');
const pool2 = await this.builder_.averagePool2d(bn3);
const reshape = this.builder_.reshape(pool2, [1, null]);
const reshape = this.builder_.reshape(pool2, [1, 2048]);
const gemm = await this.buildGemm_(reshape, '0');
return this.builder_.softmax(gemm);
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/resnet50v2_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export class ResNet50V2Nhwc {
const mean = this.builder_.averagePool2d(fusedBn, {layout});
const conv2 = await this.buildConv_(
mean, ['', '', 'logits'], {autoPad}, false);
const reshape = this.builder_.reshape(conv2, [1, null]);
const reshape = this.builder_.reshape(conv2, [1, 1001]);
return this.builder_.softmax(reshape);
}

Expand Down
2 changes: 1 addition & 1 deletion image_classification/squeezenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export class SqueezeNetNchw {
const conv25 = await this.buildConv_(fire7, 'conv25');
const pool3 = this.builder_.averagePool2d(
conv25, {windowDimensions: [13, 13], strides: [13, 13]});
const reshape0 = this.builder_.reshape(pool3, [1, null]);
const reshape0 = this.builder_.reshape(pool3, [1, 1000]);
return this.builder_.softmax(reshape0);
}

Expand Down
2 changes: 1 addition & 1 deletion image_classification/squeezenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class SqueezeNetNhwc {
const conv10 = await this.buildConv_(fire9, 'conv10');
const averagePool2d = this.builder_.averagePool2d(
conv10, {windowDimensions: [13, 13], layout});
const reshape = this.builder_.reshape(averagePool2d, [1, null]);
const reshape = this.builder_.reshape(averagePool2d, [1, 1001]);
return this.builder_.softmax(reshape);
}

Expand Down
4 changes: 2 additions & 2 deletions lenet/lenet.js
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ export class LeNet {
this.builder_.maxPool2d(add2, {windowDimensions: pool2WindowShape,
strides: pool2Strides});

const reshape1Shape = [1, null];
const reshape1Shape = [1, 800];
const reshape1 = this.builder_.reshape(pool2, reshape1Shape);

// skip the new shape, 2 int64 values
Expand All @@ -100,7 +100,7 @@ export class LeNet {

const relu = this.builder_.relu(add3);

const reshape2Shape = [1, null];
const reshape2Shape = [1, 500];
const reshape2 = this.builder_.reshape(relu, reshape2Shape);

const matmul2Shape = [10, 500];
Expand Down
6 changes: 3 additions & 3 deletions rnnoise/rnnoise.js
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export class RNNoise {
const vadGruYTransposed = this.builder_.transpose(
vadGruY, {permutation: [2, 0, 1, 3]});
const vadGruTranspose1 = this.builder_.reshape(
vadGruYTransposed, [null, this.frames_, this.vadGruHiddenSize]);
vadGruYTransposed, [1, this.frames_, this.vadGruHiddenSize]);
const concatenate1 = this.builder_.concat(
[inputDenseTanh0, vadGruTranspose1, input], 2);
const noiseGruX = this.builder_.transpose(
Expand Down Expand Up @@ -112,7 +112,7 @@ export class RNNoise {
const noiseGruYTransposed = this.builder_.transpose(
noiseGruY, {permutation: [2, 0, 1, 3]});
const noiseGruTranspose1 = this.builder_.reshape(
noiseGruYTransposed, [null, this.frames_, this.noiseGruHiddenSize]);
noiseGruYTransposed, [1, this.frames_, this.noiseGruHiddenSize]);
const concatenate2 = this.builder_.concat(
[vadGruTranspose1, noiseGruTranspose1, input], 2);
const denoiseGruX = this.builder_.transpose(
Expand Down Expand Up @@ -140,7 +140,7 @@ export class RNNoise {
const denoiseGruYTransposed = this.builder_.transpose(
denoiseGruY, {permutation: [2, 0, 1, 3]});
const denoiseGruTranspose1 = this.builder_.reshape(
denoiseGruYTransposed, [null, this.frames_, this.denoiseGruHiddenSize]);
denoiseGruYTransposed, [1, this.frames_, this.denoiseGruHiddenSize]);
const denoiseOutput0 = this.builder_.matmul(
denoiseGruTranspose1, denoiseOutputKernel0);
const biasedTensorName = this.builder_.add(
Expand Down

0 comments on commit 6fd403f

Please sign in to comment.