Doing expression plot is slow for 1000+ functions #354

Open
young-x-skyee opened this issue Aug 7, 2024 · 3 comments
Labels
❓ discussion needed Extra discussion is needed before work can commence ⏱️ optimisation 📈 plotting functionality Any issues related to plotting

Comments

@young-x-skyee
Contributor

When I tried to produce an expression plot with expression_plot for the activations of the Whisper large model, I noticed that it is very slow. Memory use was large and increased slowly at the beginning, then CPU use became high while memory use was low. That was where it got stuck.

@young-x-skyee
Contributor Author

young-x-skyee commented Aug 9, 2024

I think the problem is inside the for loop that plots all the show_only functions, at around line 465 in plot.py. The loop is actually pretty quick when I plot ~1280 functions (~200 iterations/sec on average), but with ~80000 functions it is much slower (~3 iterations/sec).

@caiw
Member

caiw commented Sep 20, 2024

Relates to #271

@caiw
Member

caiw commented Sep 22, 2024

I attempted to fix this with #378 but the changes I made had next to no effect.

Doing a little more profiling, the main slowdown I can localise is caused simply by calling pyplot.scatter thousands of times. If I comment out the call to scatter, the main loop mentioned above goes from 40 loops/second up to 200 loops/second for 1000 functions.
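
For anyone wanting to reproduce this kind of measurement outside the package, here is a minimal timing sketch. It is not the loop from plot.py: the function name `time_scatter_loop`, the random data, and the point counts are all assumptions made up for illustration. It only measures how many separate scatter calls per second matplotlib sustains as the number of calls grows.

```python
import time

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt


def time_scatter_loop(n_functions, points_per_function=5, seed=0):
    """Time n_functions separate scatter calls on one axes.

    Hypothetical stand-in for the plotting loop: each "function" is just
    a handful of random points. Returns (loops per second, number of
    collections attached to the axes afterwards).
    """
    rng = np.random.default_rng(seed)
    fig, ax = plt.subplots()
    start = time.perf_counter()
    for _ in range(n_functions):
        xy = rng.random((points_per_function, 2))
        ax.scatter(xy[:, 0], xy[:, 1], s=4)
    elapsed = time.perf_counter() - start
    n_collections = len(ax.collections)
    plt.close(fig)  # release the figure so repeated runs don't accumulate
    return n_functions / elapsed, n_collections


for n in (100, 400):
    rate, n_collections = time_scatter_loop(n)
    print(f"{n} functions: {rate:.0f} loops/sec, {n_collections} collections")
```

The absolute numbers will differ from the ones quoted above, since the real loop does more than call scatter, but comparing the rate across different values of `n_functions` gives a quick check on how the per-call cost scales.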

I think it gets progressively slower with additional functions because each call to scatter adds a new, distinct collection (a matplotlib.collections.PathCollection) to the axes, and by the time there are thousands of them matplotlib appears to do some kind of cache invalidation which gets progressively more laborious with each additional collection. Some of this seems to be autoscaling the axes after each call. I tried to disable that in #378, but the recommended way in the docs didn't remove the calls to autoscale_view which turned up in the profiler call graph, so I'm not sure what else to try. It's conceivable this is a bug in matplotlib.
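
To make the autoscaling point concrete, here is a sketch of what switching autoscaling off looks like on a bare axes. This is not necessarily what #378 did, and the data and limits are invented for illustration. It shows that the view limits stay fixed while one collection per scatter call still accumulates; it makes no claim that this removes the internal autoscale_view calls seen in the profiler.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
fig, ax = plt.subplots()

# Fix the view limits up front, then switch autoscaling off on both axes.
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.0)
ax.autoscale(False)

for _ in range(200):
    # Hypothetical data that deliberately extends beyond the fixed limits
    xy = rng.random((5, 2)) * 3.0
    ax.scatter(xy[:, 0], xy[:, 1], s=4)

# The limits are unchanged, but every call still added its own collection.
print(ax.get_xlim(), ax.get_ylim(), len(ax.collections))
```

Running this under a profiler would be one way to check whether autoscale_view still appears in the call graph with autoscaling switched off.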

The only other approach I can think of is to reduce everything to a single scatter call, by (1) stacking all the hexels for all the functions into a single matrix and (2) precomputing a matrix of colours from the function index of each significant hexel, ready to pass to scatter's c argument so all the points end up the right colour. This could get us at most a 5x speedup, which isn't nothing, but isn't nearly as good as the speedup we got in #374. The question is whether we could do that efficiently enough with 80k functions to be faster than whatever scatter is doing internally.
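
A rough sketch of steps (1) and (2), assuming each function's significant hexels are available as an (n, 2) array of plot coordinates. The helper `stack_points`, the random data, and the choice of colormap are all hypothetical, not taken from plot.py. The idea is to build one coordinate array and one RGBA row per point, then hand both to a single scatter call.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt


def stack_points(points_per_function):
    """Stack per-function (n_i, 2) arrays into one (N, 2) array.

    Also returns, for every stacked point, the index of the function it
    came from, so colours can be looked up per point.
    """
    xy = np.concatenate(points_per_function, axis=0)
    counts = [len(p) for p in points_per_function]
    func_idx = np.repeat(np.arange(len(points_per_function)), counts)
    return xy, func_idx


rng = np.random.default_rng(0)
n_functions = 1000
# Hypothetical data: each function has between 1 and 5 significant points
points_per_function = [
    rng.random((rng.integers(1, 6), 2)) for _ in range(n_functions)
]

xy, func_idx = stack_points(points_per_function)

# One RGBA colour per function, expanded to one row per point by indexing
cmap = plt.get_cmap("viridis")
function_colours = cmap(np.linspace(0.0, 1.0, n_functions))  # (n_functions, 4)
point_colours = function_colours[func_idx]                   # (N, 4)

fig, ax = plt.subplots()
ax.scatter(xy[:, 0], xy[:, 1], c=point_colours, s=4)
print(len(ax.collections), xy.shape, point_colours.shape)
```

The stacking and colour lookup are plain NumPy operations, so their cost should grow linearly with the total number of points. Whether that beats the per-call overhead at 80k functions would still need measuring on real data.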

Depending on how much this is or isn't causing a problem, we could either attempt the above, or close this as a wontfix for now and accept that plotting 80k functions on an expression plot is always going to be kind of slow.

@young-x-skyee @neukym I'd appreciate your thoughts! Have I missed some other source of slowdown?

@caiw caiw removed their assignment Sep 22, 2024
@caiw caiw added ❓ discussion needed Extra discussion is needed before work can commence and removed 📄 .nkg files 💪 enhancement New feature or request labels Sep 22, 2024