Request the decomposition for gatherElements, scatterElements and scatterND #767
One possible emulation of scatterND using tf.scatter_nd:
import tensorflow as tf
# Make an all-True tensor with updates.shape, to be scattered into the condition tensor.
trues = tf.ones(updates.shape, tf.dtypes.bool)
# Scatter the True values into a zero-initialized (False) tensor according to indices.
condition = tf.scatter_nd(indices, trues, input.shape)
# Scatter the values of updates into another zero-initialized tensor according to indices.
# Note: tf.scatter_nd sums the updates of duplicate indices.
scatter = tf.scatter_nd(indices, updates, input.shape)
# Select the scattered value or the input value based on condition.
output = tf.where(condition, scatter, input)
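The three steps above (scatter a True mask, scatter the updates, then select) can be checked without TensorFlow. The following is a minimal pure-Python model, assuming tensors are row-major flat lists with an explicit shape and that indices contains no duplicate entries (where tf.scatter_nd would sum them). The helper names `scatter_nd_into`, `scatter_nd_reference` and `scatter_nd_emulated` are hypothetical; the sketch only compares the mask-and-select emulation with a direct scatterND reference.

```python
from math import prod


def strides(shape):
    """Row-major strides: stride[i] is the product of shape[i + 1:]."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]


def scatter_nd_into(base, shape, indices, updates):
    """Return a copy of the flat list `base` with addressed slices overwritten.

    Each entry of `indices` is a tuple of k <= len(shape) coordinates that
    addresses a contiguous slice of prod(shape[k:]) elements. Assumes no
    duplicate entries in `indices`.
    """
    out = list(base)
    st = strides(shape)
    k = len(indices[0])
    slice_size = prod(shape[k:])
    for n, idx in enumerate(indices):
        start = sum(i * s for i, s in zip(idx, st))
        out[start:start + slice_size] = updates[n * slice_size:(n + 1) * slice_size]
    return out


def scatter_nd_reference(data, shape, indices, updates):
    # scatterND semantics: start from the input and overwrite addressed slices.
    return scatter_nd_into(data, shape, indices, updates)


def scatter_nd_emulated(data, shape, indices, updates):
    # Mirrors the tf.scatter_nd / tf.where emulation step by step.
    size = prod(shape)
    # 1. Scatter True values into an all-False mask.
    condition = scatter_nd_into([False] * size, shape, indices, [True] * len(updates))
    # 2. Scatter the updates into a zero-initialized tensor.
    scattered = scatter_nd_into([0] * size, shape, indices, updates)
    # 3. Take the scattered value where the mask is True, else the input value.
    return [s if c else d for c, s, d in zip(condition, scattered, data)]


data = [1, 2, 3, 4, 5, 6, 7, 8]
print(scatter_nd_emulated(data, [8], [(4,), (3,), (1,), (7,)], [9, 10, 11, 12]))
```

Both functions give [1, 11, 3, 10, 9, 6, 7, 12] for this input. With slice updates, e.g. shape [2, 2], indices [(1,)] and updates [5, 6], the whole second row is replaced.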
Maybe gatherElements can be supported with tfl.gather_nd. The test cases and the documentation show that gather_nd can gather not only slices but also individual elements. However, gather_nd has no axis argument, so the indices need to be converted, for example:
Convert to:
So the indices must be a constant operand, and the converted locations are inserted with a loop. At the current stage, the input is required to be two-dimensional.
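For a two-dimensional input, the conversion replaces the coordinate at position axis with the value read from indices, so each output element maps to one full gather_nd location:

```latex
\begin{aligned}
\text{axis} = 0:&\quad \mathrm{output}[i][j] = \mathrm{input}[\mathrm{indices}[i][j]][j]
  &&\Rightarrow\quad \mathrm{location}_{(i,j)} = (\mathrm{indices}[i][j],\ j) \\
\text{axis} = 1:&\quad \mathrm{output}[i][j] = \mathrm{input}[i][\mathrm{indices}[i][j]]
  &&\Rightarrow\quad \mathrm{location}_{(i,j)} = (i,\ \mathrm{indices}[i][j])
\end{aligned}
```

For example (values chosen for illustration), with input = [[1, 2], [3, 4]], axis = 1 and indices = [[0, 0], [1, 0]], the gather_nd indices become [[0, 0], [0, 0], [1, 1], [1, 0]]. gather_nd then returns [1, 1, 4, 3], which reshapes to the gatherElements result [[1, 1], [4, 3]].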
@fujunwei, thanks for sharing the idea of emulating gatherElements with gatherND!
I suppose the emulation could support N-dimensional input as follows:
// Generate gatherND's indicesND from gatherElements' indices and axis.
// Pseudocode: indices is a constant tensor; numberOfElements(), getLocationFromIndex()
// and getValueByLocation() are illustrative helpers, not WebNN API methods.
// indicesND.rank == 2, so it can be treated as an array of locations.
// indicesND.shape[0] == the number of elements of indices; indicesND.shape[1] == indices.rank.
let indicesND = [];
for (let i = 0; i < indices.numberOfElements(); ++i) {
  // Multi-dimensional location of the i-th element of indices (row-major order).
  let location = indices.getLocationFromIndex(i);
  // Replace the coordinate along axis with the index value stored at that location.
  location[axis] = indices.getValueByLocation(location);
  indicesND.push(location);
}
// For gatherElements, output.shape == indices.shape.
const output = reshape(gatherND(input, indicesND), indices.shape);
webnn-baseline has a reference implementation of
Please note that this emulation only works when indices is a constant operand.
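The pseudocode above can be translated into a runnable pure-Python sketch. Nested lists stand in for constant tensors, and `get`, `shape_of`, `gather_nd` and `reshape` are hypothetical stand-ins for the illustrative helpers and the WebNN ops; only element-wise gatherND (every location has input.rank coordinates) is modeled.

```python
from itertools import product


def get(nested, location):
    """Read the element of a nested-list tensor at a multi-dimensional location."""
    for i in location:
        nested = nested[i]
    return nested


def shape_of(nested):
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


def gather_elements_to_gather_nd_indices(indices, axis):
    """One gatherND location per element of indices, in row-major order."""
    indices_nd = []
    for location in product(*(range(d) for d in shape_of(indices))):
        location = list(location)
        # The right-hand side is read before the axis coordinate is overwritten.
        location[axis] = get(indices, location)
        indices_nd.append(location)
    return indices_nd


def gather_nd(data, indices_nd):
    """Element-wise gatherND: each location addresses a single element."""
    return [get(data, loc) for loc in indices_nd]


def reshape(flat, shape):
    """Reshape a flat row-major list into a nested list of the given shape."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[k * step:(k + 1) * step], shape[1:]) for k in range(shape[0])]


def gather_elements_emulated(data, indices, axis):
    # For gatherElements, output.shape == indices.shape.
    indices_nd = gather_elements_to_gather_nd_indices(indices, axis)
    return reshape(gather_nd(data, indices_nd), shape_of(indices))


print(gather_elements_emulated([[1, 2], [3, 4]], [[0, 0], [1, 0]], axis=1))
```

This prints [[1, 1], [4, 3]]. Because the conversion reads index values while building the locations, it only works when indices is known ahead of time, which matches the constant-operand restriction stated in the comment.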
gatherElements
Related: mil.ops.defs.iOS15.scatter_gather.gather_along_axis, tf.experimental.numpy.take_along_axis?, DML_GATHER_ELEMENTS
Decomposition: NA
Data Types: input (*), indices (int32, uint32, int64)

scatterElements
Related: mil.ops.defs.iOS17.scatter_gather.scatter_along_axis, TF=?, DML_SCATTER_ELEMENTS
Decomposition: NA
Data Types: input (*), updates (same as input), indices (int32, uint32, int64)

scatterND
Related: tf.scatter_nd, mil.ops.defs.iOS15.scatter_gather.scatter_nd, DML_SCATTER_ND
Decomposition: NA
Data Types: input (*), updates (same as input), indices (int32, uint32, int64)
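By analogy with the gatherElements emulation discussed in this thread, scatterElements could plausibly be lowered to an element-wise scatterND using the same index conversion when indices is a constant operand. This sketch is not a decomposition proposed in the issue: it uses nested Python lists in place of tensors, models only the no-reduction case, assumes indices holds no duplicate targets, and its helper names (`get`, `set_at`, `scatter_elements_emulated`) are hypothetical.

```python
import copy
from itertools import product


def shape_of(nested):
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


def get(nested, location):
    """Read the element at a multi-dimensional location."""
    for i in location:
        nested = nested[i]
    return nested


def set_at(nested, location, value):
    """Write one element at a multi-dimensional location (in place)."""
    for i in location[:-1]:
        nested = nested[i]
    nested[location[-1]] = value


def scatter_elements_emulated(data, indices, updates, axis):
    locations = list(product(*(range(d) for d in shape_of(indices))))
    # Convert scatterElements indices into scatterND locations (constant indices only).
    indices_nd = []
    for loc in locations:
        target = list(loc)
        target[axis] = get(indices, loc)
        indices_nd.append(target)
    # Flatten updates in the same row-major order as indices_nd.
    flat_updates = [get(updates, loc) for loc in locations]
    # Element-wise scatterND: overwrite one element of a copy of data per location.
    out = copy.deepcopy(data)
    for target, value in zip(indices_nd, flat_updates):
        set_at(out, target, value)
    return out


print(scatter_elements_emulated([[1, 2, 3, 4, 5]], [[1, 3]], [[1.1, 2.1]], axis=1))
```

This prints [[1, 1.1, 3, 2.1, 5]]. Whether a backend can use this lowering depends on indices being constant at build time, the same restriction noted for gatherElements.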