This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
I have an observation and I'm hoping someone can advise.
I have a scenario where I maintain a large table of vectors: a basic (n, m) array holding n vectors of size m. Some other part of the system generates indices into this table, and I want to pull out the rows at those indices. (More background: we're building a hash-table version of NeRF.)
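For context, here is a minimal NumPy sketch of what such a row gather means for autodiff. The table size, the index values and the all-ones upstream gradient are made up for illustration. The forward pass is a plain lookup; the backward pass has to scatter-add the output gradient back into an (n, m) gradient for the table, and any row hit more than once accumulates all of its contributions.

```python
import numpy as np

# Hypothetical sizes: a table of n vectors, each of length m.
n, m = 8, 3
table = np.arange(n * m, dtype=np.float32).reshape(n, m)

# Indices produced elsewhere (e.g. by a hash function); duplicates are allowed.
idx = np.array([5, 1, 5, 7])

# Forward: pull one row per index -> shape (len(idx), m).
rows = table[idx]

# Backward: given dL/d(rows), the gradient w.r.t. the table is a scatter-add.
# Row 5 is looked up twice, so it receives both gradient rows.
grad_rows = np.ones_like(rows)
grad_table = np.zeros_like(table)
np.add.at(grad_table, idx, grad_rows)  # unbuffered: duplicates accumulate
```

The forward cost scales with the number of lookups, while the backward has to deal with collisions between indices, which is where implementations can differ a lot.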
So I have a set of indices, and I want to gather the corresponding rows out of the table for use elsewhere. There are two operators in MXNet that will do the job: `gather_nd` and `take`.
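The two operators expect the indices in different layouts. As I understand MXNet's documented semantics, `take(a, indices, axis=0)` takes a flat vector of row numbers, while `gather_nd(data, indices)` expects indices of shape (M, N), where the first axis enumerates the M leading dimensions being indexed; a row gather therefore needs indices of shape (1, N). The sketch below emulates both conventions in NumPy; `take_rows` and `gather_nd_rows` are hypothetical helpers, not MXNet functions.

```python
import numpy as np

def take_rows(data, indices):
    # Mirrors take(data, indices, axis=0): a flat vector of row numbers.
    return np.take(data, indices, axis=0)

def gather_nd_rows(data, indices):
    # Mirrors gather_nd semantics: indices has shape (M, N), one row per
    # indexed axis. For row gathers M == 1, so out[i] = data[indices[0, i]].
    return data[tuple(indices)]

table = np.random.default_rng(0).standard_normal((10, 4)).astype(np.float32)
idx = np.array([3, 0, 3, 9])

a = take_rows(table, idx)
b = gather_nd_rows(table, idx[np.newaxis, :])  # reshape (N,) -> (1, N)
```

Both calls return the same (4, 4) block of rows; only the index layout differs.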
The number of indices can exceed 100k and even reach 1,000k (1M):

- At 100k indices, `take` does a forward pass in under 1 ms, but its backward pass takes about 45 ms. Meanwhile, `gather_nd` does a forward pass in about 16 ms and a backward pass in under 1 ms.
- At 1,000k indices, `take` is 4 ms forward and 400 ms backward, while `gather_nd` is 170 ms forward and 1 ms backward.
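The asymmetry in the timings above suggests that the operators compute the gradient w.r.t. the table in different ways. Without claiming how MXNet implements either backward pass, here is a NumPy sketch of two strategies that produce the same scatter-add: an unbuffered accumulation that applies duplicates one at a time, and a sort-based segment sum that reduces each run of equal indices once. The sizes are arbitrary; one can wrap each call in `time.perf_counter()` to compare them on a given machine.

```python
import numpy as np

def scatter_add_unbuffered(n, idx, grad_rows):
    # Add each gradient row into its table row; duplicates are applied one by one.
    out = np.zeros((n, grad_rows.shape[1]), dtype=grad_rows.dtype)
    np.add.at(out, idx, grad_rows)
    return out

def scatter_add_sorted(n, idx, grad_rows):
    # Sort so duplicate indices form contiguous runs, reduce each run to a
    # single row, then write exactly one row per unique index.
    out = np.zeros((n, grad_rows.shape[1]), dtype=grad_rows.dtype)
    if idx.size == 0:
        return out
    order = np.argsort(idx, kind="stable")
    sorted_idx = idx[order]
    uniq, starts = np.unique(sorted_idx, return_index=True)  # run starts
    out[uniq] = np.add.reduceat(grad_rows[order], starts, axis=0)
    return out

rng = np.random.default_rng(0)
n, m, k = 1_000, 8, 100_000  # table rows, row width, number of lookups
idx = rng.integers(0, n, size=k)
grad_rows = rng.standard_normal((k, m))
a = scatter_add_unbuffered(n, idx, grad_rows)
b = scatter_add_sorted(n, idx, grad_rows)
```

The results agree up to floating-point summation order; the sorted version trades a sort for avoiding repeated writes to the same row.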
So, the obvious question: is there a way to get the best of both worlds here, i.e. the fast forward pass of `take` combined with the fast backward pass of `gather_nd`?
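One way to frame "best of both worlds" is a custom op whose forward is a plain row lookup and whose backward is whatever scatter-add is fastest. The sketch below is a hypothetical `RowGather` class in NumPy, not an MXNet API; its backward uses `np.bincount` over flattened (row, column) ids. In MXNet such a pairing would presumably be wired up through a custom-op hook such as `mx.autograd.Function` or `mx.operator.CustomOp`, but I have not verified how that performs.

```python
import numpy as np

class RowGather:
    """Hypothetical op: forward is a row lookup (the job take does),
    backward accumulates output gradients into a dense table gradient."""

    def __init__(self, table_shape):
        self.n, self.m = table_shape
        self.idx = None

    def forward(self, table, idx):
        self.idx = np.asarray(idx)  # saved for the backward pass
        return table[self.idx]

    def backward(self, grad_rows):
        # Turn every (row, column) pair into one linear id, then let bincount
        # sum all contributions landing on the same table entry.
        cols = np.arange(self.m)
        flat_ids = (self.idx[:, None] * self.m + cols).ravel()
        grad = np.bincount(flat_ids, weights=grad_rows.ravel(),
                           minlength=self.n * self.m)
        return grad.reshape(self.n, self.m)

table = np.arange(12, dtype=np.float64).reshape(4, 3)
op = RowGather(table.shape)
out = op.forward(table, [2, 0, 2])
g = op.backward(np.ones_like(out))
```

Here row 2 is looked up twice, so its gradient is 2 in every column, row 0 gets 1, and untouched rows stay 0.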
Is there a better operator for gathering rows from the table? I also tried `Embedding`: in my test it looked like the best of both worlds, but in the real app it was slow on the backward pass.