Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Sep 15, 2023
1 parent 65b0020 commit 70a844c
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
3 changes: 1 addition & 2 deletions js/.eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ module.exports = {
'no-new-wrappers': 'error',
'no-octal': 'error',
'no-octal-escape': 'error',
'no-param-reassign': 'error',
'no-param-reassign': 'off',
'no-redeclare': 'off',
'@typescript-eslint/no-redeclare': ['error'],
'no-regex-spaces': 'error',
Expand Down Expand Up @@ -158,7 +158,6 @@ module.exports = {
'@typescript-eslint/restrict-plus-operands': 'off',
'import/no-internal-modules': 'off',
'prefer-arrow/prefer-arrow-functions': 'off',
'no-param-reassign': 'off',
'guard-for-in': 'off'
}
}, {
Expand Down
3 changes: 1 addition & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-like-util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ type CreateOpProgramShader =
doBroadcast: boolean, funcCall: BinaryFunctionCall, typeOutput: number, additionalImplementation?: string) =>
string;

/* eslint-disable no-param-reassign */
const createOpProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], funcCall: BinaryFunctionCall,
createOpProgramShader: CreateOpProgramShader, additionalImplementation?: string,
Expand Down Expand Up @@ -77,7 +76,7 @@ const createOpProgramInfo =
};

// This is used for ops like binary, where.
export const createOpProgramInfoLoader =
export const createBinaryLikeProgramInfoLoader =
(inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall,
createOpProgramShader: CreateOpProgramShader, additionalImplementation?: string, cacheKey?: string,
outputDataType?: number): ProgramInfoLoader => {
Expand Down
26 changes: 15 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {ComputeContext} from '../types';

import {BinaryCustomExpression, BinaryFunctionCall, createOpProgramInfoLoader, fourAssignment, getBroadcastIndexComponent} from './binary-like-util';
import {BinaryCustomExpression, BinaryFunctionCall, createBinaryLikeProgramInfoLoader, fourAssignment, getBroadcastIndexComponent} from './binary-like-util';
import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common';

const createOpProgramShader =
Expand Down Expand Up @@ -81,27 +81,30 @@ const createOpProgramShader =
};

export const add = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`, createOpProgramShader));
context.compute(
createBinaryLikeProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`, createOpProgramShader));
};

export const div = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`, createOpProgramShader));
context.compute(
createBinaryLikeProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`, createOpProgramShader));
};

export const equal = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader(
context.inputs, 'Equal', ({scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4<u32>(${a}==${b})`}),
createOpProgramShader, undefined, undefined, DataType.bool));
};

export const mul = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`, createOpProgramShader));
context.compute(
createBinaryLikeProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`, createOpProgramShader));
};

export const pow = (context: ComputeContext): void => {
const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value;
const roundStr = type === 'i32' ? 'round' : '';
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader(
context.inputs, 'Pow',
({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}),
createOpProgramShader,
Expand All @@ -123,30 +126,31 @@ export const pow = (context: ComputeContext): void => {
};

export const sub = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`, createOpProgramShader));
context.compute(
createBinaryLikeProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`, createOpProgramShader));
};

export const greater = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader(
context.inputs, 'Greater', ({scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4<u32>(${a}>${b})`}),
createOpProgramShader, undefined, undefined, DataType.bool));
};

export const less = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader(
context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4<u32>(${a}<${b})`}),
createOpProgramShader, undefined, undefined, DataType.bool));
};

export const greaterOrEqual = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader(
context.inputs, 'GreaterOrEqual',
({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})`}), createOpProgramShader,
undefined, undefined, DataType.bool));
};

export const lessOrEqual = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader(
context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4<u32>(${a}<=${b})`}),
createOpProgramShader, undefined, undefined, DataType.bool));
};
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {ComputeContext} from '../types';

import {BinaryCustomExpression, BinaryFunctionCall, createOpProgramInfoLoader, fourAssignment, getBroadcastIndexComponent} from './binary-like-util';
import {BinaryCustomExpression, BinaryFunctionCall, createBinaryLikeProgramInfoLoader , fourAssignment, getBroadcastIndexComponent} from './binary-like-util';
import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common';

const createOpProgramShader =
Expand Down Expand Up @@ -89,6 +89,6 @@ const createOpProgramShader =
};

export const where = (context: ComputeContext): void => {
context.compute(createOpProgramInfoLoader(
context.compute(createBinaryLikeProgramInfoLoader (
context.inputs, 'Where', (a, b, c) => `select(${b}, ${a}, ${c})`, createOpProgramShader));
};

0 comments on commit 70a844c

Please sign in to comment.