diff --git a/src/logical.js b/src/logical.js index 7955ef7..d60bca9 100644 --- a/src/logical.js +++ b/src/logical.js @@ -34,3 +34,4 @@ export const logicalNot = (input) => logicalNotImpl(input); export const logicalOr = (inputA, inputB) => binary(inputA, inputB, (a, b) => (!!a || !!b ? 1 : 0)); export const logicalXor = (inputA, inputB) => binary(inputA, inputB, (a, b) => (!!a ^ !!b)); +export const notEqual = (inputA, inputB) => logicalNot(equal(inputA, inputB)); diff --git a/test/logical_test.js b/test/logical_test.js index 043e5dd..878d28f 100644 --- a/test/logical_test.js +++ b/test/logical_test.js @@ -1,7 +1,7 @@ 'use strict'; import {equal, greater, greaterOrEqual, lesser, lesserOrEqual, - logicalAnd, logicalNot, logicalOr, logicalXor} from '../src/logical.js'; + logicalAnd, logicalNot, logicalOr, logicalXor, notEqual} from '../src/logical.js'; import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; @@ -515,4 +515,52 @@ describe('test logical', function() { }; testLogical(inputA, inputB, expected, logicalXor); }); + + it('notEqual 0D scalar', function() { + const inputA = { + shape: [], + data: [0.5], + }; + const inputB = { + shape: [], + data: [0.5], + }; + const expected = { + shape: [], + data: [0], + }; + testLogical(inputA, inputB, expected, notEqual); + }); + + it('notEqual 4D', function() { + const inputA = { + shape: [1, 2, 2, 1], + data: [-1, 1, 1, 0], + }; + const inputB = { + shape: [1, 2, 2, 1], + data: [1, 0, -1, 0], + }; + const expected = { + shape: [1, 2, 2, 1], + data: [1, 1, 1, 0], + }; + testLogical(inputA, inputB, expected, notEqual); + }); + + it('notEqual 4D broadcast', function() { + const inputA = { + shape: [1, 2, 2, 1], + data: [-1, 1, 1, 0], + }; + const inputB = { + shape: [1], + data: [1], + }; + const expected = { + shape: [1, 2, 2, 1], + data: [1, 0, 0, 1], + }; + testLogical(inputA, inputB, expected, notEqual); + }); });