
change batch_size in sparsify #639

Open
MUCDK opened this issue Dec 14, 2023 · 11 comments

MUCDK (Collaborator) commented Dec 14, 2023

Needed to keep the solver fast (high batch size) while preventing OOM in sparsify (which needs a low batch size); see theislab/cellrank#1146 (comment).

selmanozleyen (Collaborator) commented
Hey @giovp @MUCDK, I ran a quick benchmark to measure peak memory.

From what I checked, the largest arrays allocated during solve have shapes (batch_size_point_cloud, d), (n, 1), or (m, 1), and apply_lse_kernel contributes to these allocations. Meanwhile, apply wraps apply_lse_kernel in vmap, so it runs batch_size_sparse times and stacks the results: each apply_lse_kernel call produces an (n, 1) array, and apply returns an (n, batch_size_sparse) array.
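A minimal JAX sketch of the vmap-and-stack pattern described above. lse_kernel is a hypothetical stand-in for apply_lse_kernel (a toy squared-Euclidean log-sum-exp, not ott's implementation), and the sizes are arbitrary toy values.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

n, m, d = 7, 5, 3  # toy sizes, chosen arbitrarily
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n, d))
y = jax.random.normal(key_y, (m, d))
epsilon = 1.0

def lse_kernel(v):
    """Hypothetical stand-in for apply_lse_kernel: one (m,) vector -> (n, 1)."""
    # Squared Euclidean cost via norms, so no (n, m, d) array is formed.
    cost = (jnp.sum(x**2, axis=1)[:, None] + jnp.sum(y**2, axis=1)[None, :]
            - 2.0 * x @ y.T)  # (n, m)
    return logsumexp((-cost + v[None, :]) / epsilon, axis=1, keepdims=True)  # (n, 1)

batch_size_sparse = 4
vs = jnp.ones((batch_size_sparse, m))  # one input vector per row
out = jax.vmap(lse_kernel)(vs)  # (batch_size_sparse, n, 1): one result per vector
stacked = out[:, :, 0].T  # (n, batch_size_sparse): results stacked column-wise
print(out.shape, stacked.shape)  # (4, 7, 1) (7, 4)
```

vmap maps over the leading axis of vs by default, so the per-vector (n, 1) results come back with a new leading batch axis; the transpose only reorders them into the (n, batch_size_sparse) layout mentioned above.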

If m = n, these should be the theoretical memory complexities:

  • solve memory complexity is O(max(n, batch_size_point_cloud*d))
  • apply memory complexity is O(max(n*batch_size_sparse, batch_size_point_cloud*d))
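To get a feel for these bounds, the sketch below plugs in the n and d used in the benchmark, assuming m = n (the benchmark itself uses m = 1700), float32 (4 bytes per element), and dropping all constant factors. It is plain arithmetic on the two formulas, not a model of the process's peak memory.

```python
def solve_elems(n, batch_size_point_cloud, d):
    # O(max(n, batch_size_point_cloud * d)), assuming m = n
    return max(n, batch_size_point_cloud * d)

def apply_elems(n, batch_size_point_cloud, d, batch_size_sparse):
    # O(max(n * batch_size_sparse, batch_size_point_cloud * d)), assuming m = n
    return max(n * batch_size_sparse, batch_size_point_cloud * d)

n, d = 1400, 400  # values from the benchmark setup in this thread
for bs_pc in (400, 120):
    for bs_sp in (400, 600):
        s = solve_elems(n, bs_pc, d)
        a = apply_elems(n, bs_pc, d, bs_sp)
        # 4 bytes per float32 element, reported in MB
        print(f"pc={bs_pc:>3} sp={bs_sp}: solve {s * 4 / 1e6:.2f} MB, apply {a * 4 / 1e6:.2f} MB")
```

Under these assumptions both bounds come to a few MB, far below the measured peaks, which also include everything else the process holds; that gap is consistent with apply materializing a larger array than the bound predicts.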

However, as in the linked issue, apply tries to allocate an array of shape (batch_size, d, m). I think this is due to the vmap usage in apply: jax.make_jaxpr(solver)(ot_prob) yields arrays of at most 2 dimensions, whereas jax.make_jaxpr(ot.apply)(jnp.eye(19, n)) yields 3-D arrays. (You can also run these after solving in this notebook: https://github.com/ott-jax/ott/blob/main/docs/tutorials/point_clouds.ipynb)
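The rank check described above can be reproduced on a toy function. apply_one below is a hypothetical stand-in (not ott's apply); the point is that vmap adds a batch axis to every intermediate that depends on the mapped input, so a 2-D (n, m) intermediate becomes a 3-D (batch, n, m) one. The helper only inspects top-level jaxpr equations.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

n, m, d, batch = 6, 5, 3, 4  # toy sizes
x = jnp.ones((n, d))
y = jnp.ones((m, d))

def apply_one(v):
    # v: (m,) -> (n,); every intermediate here is at most 2-D
    cost = x @ y.T  # (n, m)
    return logsumexp(-cost + v[None, :], axis=1)  # (n,)

def max_ndim(closed_jaxpr):
    # Largest rank among the outputs of the top-level equations.
    return max(
        (len(var.aval.shape) for eqn in closed_jaxpr.jaxpr.eqns for var in eqn.outvars),
        default=0,
    )

single = jax.make_jaxpr(apply_one)(jnp.zeros(m))
batched = jax.make_jaxpr(jax.vmap(apply_one))(jnp.zeros((batch, m)))
print(max_ndim(single), max_ndim(batched))  # 2 3
```

Printing single and batched directly shows the same thing with concrete shapes, e.g. f32[6,5] versus f32[4,6,5] for the result of the addition.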

Benchmark setup: I ran with n, m, d = 1400, 1700, 400 and max_iterations=2 for solve (batch_size_pc is batch_size_point_cloud, batch_size_sp is batch_size_sparse).

[33.33%] ··· benchmarks.PointCloud.peakmem_apply1                                ok
[33.33%] ··· =============== ======= =======
             --               batch_size_sp 
             --------------- ---------------
              batch_size_pc    400     600  
             =============== ======= =======
                   400        1.19G   1.52G 
                   120         650M    799M 
             =============== ======= =======

[66.67%] ··· benchmarks.PointCloud.peakmem_apply2                                ok
[66.67%] ··· =============== ======= =======
             --               batch_size_sp 
             --------------- ---------------
              batch_size_pc    400     600  
             =============== ======= =======
                   400        1.99G   2.24G 
                   120        1.03G   1.38G 
             =============== ======= =======

[100.00%] ··· benchmarks.PointCloud.peakmem_solve                                 ok
[100.00%] ··· =============== ====== ======
              --              batch_size_sp
              --------------- -------------
               batch_size_pc   400    600  
              =============== ====== ======
                    400        401M   402M 
                    120        384M   394M 
              =============== ====== ======

Here is the benchmark code (I ran it with asv run --quick --python=same):
https://gist.github.com/selmanozleyen/70d3ed29aa7841bcaa41f18165f64ab5

MUCDK (Collaborator, Author) commented Mar 6, 2024

Great, thanks. What is batch_size_sparse?

MUCDK closed this as completed Mar 6, 2024
MUCDK reopened this Mar 6, 2024

MUCDK (Collaborator, Author) commented Mar 6, 2024

So it seems like solve and apply require the same memory?
Could you please check by setting batch_size_sparse to 1?

selmanozleyen (Collaborator) commented Mar 6, 2024

> So it seems like solve and apply require the same memory? Could you please check by setting batch_size_sparse to 1?

No, solve doesn't exceed 500 MB. Here are the results with batch_size_sp=1; apply's peak memory is then similar to solve's.

[33.33%] ··· benchmarks.PointCloud.peakmem_apply1                                ok
[33.33%] ··· =============== ======= ======= ======
             --                  batch_size_sp     
             --------------- ----------------------
              batch_size_pc    400     600     1   
             =============== ======= ======= ======
                   400        1.31G   1.77G   375M 
                   120         656M    795M   377M 
             =============== ======= ======= ======

[66.67%] ··· benchmarks.PointCloud.peakmem_apply2                                ok
[66.67%] ··· =============== ======= ======= ======
             --                  batch_size_sp     
             --------------- ----------------------
              batch_size_pc    400     600     1   
             =============== ======= ======= ======
                   400        1.64G   2.48G   393M 
                   120        1.05G   1.37G   379M 
             =============== ======= ======= ======

[100.00%] ··· benchmarks.PointCloud.peakmem_solve                                 ok
[100.00%] ··· =============== ====== ====== ======
              --                 batch_size_sp    
              --------------- --------------------
               batch_size_pc   400    600     1   
              =============== ====== ====== ======
                    400        403M   403M   390M 
                    120        393M   396M   394M 
              =============== ====== ====== ======


MUCDK (Collaborator, Author) commented Mar 6, 2024

Yeah, but it seems like apply requires just batch_size_sparse times more memory, which means that apply and solve require the same amount when we apply to only one vector. Hence, vmap is compatible with the batch_size argument, right?

All in all, this means that we require batch_size_sparse times more memory in sparsify, right?
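The linear scaling can be checked without allocating anything by using jax.eval_shape on a vmapped function. per_vector_block is a hypothetical per-vector intermediate with toy sizes; the absolute byte counts have nothing to do with the benchmark above, only the ratio matters.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np

n, m = 6, 5  # toy sizes

def per_vector_block(v):
    # Hypothetical per-vector intermediate: an (n, m) block built from v of shape (m,).
    return jnp.zeros((n, 1)) + v[None, :]

def block_bytes(batch_size_sparse):
    # jax.eval_shape only traces, so nothing is allocated here.
    spec = jax.ShapeDtypeStruct((batch_size_sparse, m), jnp.float32)
    out = jax.eval_shape(jax.vmap(per_vector_block), spec)
    return math.prod(out.shape) * np.dtype(out.dtype).itemsize

print(block_bytes(1), block_bytes(400))  # 120 48000
```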

selmanozleyen (Collaborator) commented

> Yeah, but it seems like apply requires just batch_size_sparse times more memory, which means that apply and solve require the same amount when we apply to only one vector. Hence, vmap is compatible with the batch_size argument, right?

> All in all, this means that we require batch_size_sparse times more memory in sparsify, right?

Yes, that is correct.

MUCDK (Collaborator, Author) commented Mar 6, 2024

Okay, now the question is whether it's faster to decrease the batch size in ott-jax (i.e., in PointCloud) and, in turn, increase the batch size in sparsify.

Any chance you could benchmark this?
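For a quick standalone comparison outside the asv setup used in this thread, a timing helper has to exclude the first (compiling) call and wait for JAX's asynchronous dispatch. The sketch below does both; K, batched_apply, and the sizes are hypothetical toy stand-ins, not ott's PointCloud.

```python
import functools
import time

import jax
import jax.numpy as jnp

def best_time(fn, *args, repeats=3):
    """Best-of-`repeats` wall time of fn(*args); fn must return a single jax.Array."""
    fn(*args).block_until_ready()  # warm-up call, so compilation is not timed
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()  # wait for asynchronous dispatch to finish
        best = min(best, time.perf_counter() - start)
    return best

K = jnp.ones((512, 512))  # toy dense "kernel" (hypothetical stand-in for the geometry)
vs = jnp.ones((512, 64))  # 64 input vectors as columns

def batched_apply(v, batch):
    # Apply K to the columns of v, `batch` columns at a time.
    chunks = [K @ v[:, i:i + batch] for i in range(0, v.shape[1], batch)]
    return jnp.concatenate(chunks, axis=1)

for batch in (8, 64):
    f = jax.jit(functools.partial(batched_apply, batch=batch))
    print(batch, f"{best_time(f, vs) * 1e3:.2f} ms")
```

XLA may fuse or reorder the chunked matmuls, so the toy numbers say nothing about ott itself; the helper only illustrates the measurement pattern.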

selmanozleyen (Collaborator) commented

Yep, here are the results:

[50.00%] ··· Running (benchmarks.PointCloud.time_apply1--).
[75.00%] ··· benchmarks.PointCloud.peakmem_apply1                               ok
[75.00%] ··· =============== ======= ======= =======
             --                   batch_size_sp     
             --------------- -----------------------
              batch_size_pc    1200    600     120  
             =============== ======= ======= =======
                   1200        5.5G   5.22G   1.43G 
                   600        5.44G   5.21G    1.4G 
                   120        1.24G    792M    502M 
             =============== ======= ======= =======

[100.00%] ··· benchmarks.PointCloud.time_apply1                                  ok
[100.00%] ··· =============== ============ =========== ===========
              --                         batch_size_sp            
              --------------- ------------------------------------
               batch_size_pc      1200         600         120    
              =============== ============ =========== ===========
                    1200       8.41±0.01s   2.30±0.1s    490±5ms  
                    600         5.24±0s      2.12±0s     491±1ms  
                    120        3.24±0.02s    1.57±0s    422±0.2ms 
              =============== ============ =========== ===========

MUCDK (Collaborator, Author) commented Mar 8, 2024

Hence, it's faster if we have a large batch size in point_cloud and a small batch size in sparsify, is this correct?
Thus, there doesn't seem to be a need to change anything in the code, as we would not want to decrease the batch_size in PointCloud, right?

Thus, using batch_size=1 in sparsify would prevent running OOM and would still be the fastest option, even though it's not great, as it takes forever.

One last thing, @selmanozleyen: did you convert the output to a csr_matrix within the for loop to simulate what we do in sparsify? This might add a lot of overhead.
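To make the question concrete, here is a hypothetical sketch of a batched sparsify loop that converts each batch to CSR inside the loop. apply_fn, the column-batching scheme, and the threshold rule are assumptions for illustration; this is not moscot's sparsify implementation (in particular, it does not implement mode='min_row').

```python
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp

def sparsify_batched(apply_fn, m, batch_size, threshold):
    """Hypothetical sketch: materialize a transport map one column batch at a time.

    apply_fn maps an (m, b) block of input vectors to an (n, b) block.
    """
    blocks = []
    for start in range(0, m, batch_size):
        b = min(batch_size, m - start)
        eye_block = jnp.eye(m, b, k=-start)  # columns start..start+b-1 of the m x m identity
        dense = np.array(apply_fn(eye_block))  # device -> host copy, once per batch
        dense[dense < threshold] = 0.0  # hypothetical rule: drop small entries
        blocks.append(sp.csr_matrix(dense))  # CSR conversion inside the loop
    return sp.hstack(blocks, format="csr")

rng = np.random.default_rng(0)
K = jnp.asarray(rng.random((4, 6)), dtype=jnp.float32)  # toy dense (n=4, m=6) map
result = sparsify_batched(lambda block: K @ block, m=6, batch_size=4, threshold=0.5)
print(result.shape)  # (4, 6)
```

np.array (rather than np.asarray) is used because it returns a writable host copy; the per-batch conversion and host copy are exactly the costs a benchmark of apply alone would not capture.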

selmanozleyen (Collaborator) commented Mar 8, 2024

> Hence, it's faster if we have a large batch size in point_cloud and a small batch size in sparsify, is this correct? Thus, there doesn't seem to be a need to change anything in the code, as we would not want to decrease the batch_size in PointCloud, right?

Yes, for these values of m, n, d.

> Thus, using batch_size=1 in sparsify would prevent running OOM and would still be the fastest option, even though it's not great, as it takes forever.

For the whole sparsify method, I am not sure batch_size=1 would be best, but it seems like the smaller, the better. Also, is it really that slow? Is there a chance that in these cases the data was on the GPU? sparsify seems to implicitly copy from GPU to CPU in each iteration of the for loop, which is usually very slow. Unless I got something wrong, a warning might be a good idea if the data is on the GPU.
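A warning of that kind could be sketched as below. warn_if_not_on_cpu is a hypothetical helper, not part of moscot; it relies on jax.Array.devices() and Device.platform. Note that the later update in this thread found keeping the solution on the GPU to be faster overall, so such a message would be informational rather than a reason to move data to the CPU.

```python
import warnings

import jax
import jax.numpy as jnp

def warn_if_not_on_cpu(arr, name="solution"):
    """Hypothetical helper: warn when a JAX array lives on an accelerator."""
    platforms = {device.platform for device in arr.devices()}
    if platforms - {"cpu"}:
        warnings.warn(
            f"{name} is on {sorted(platforms)}; sparsify will copy each batch to the host.",
            stacklevel=2,
        )

# On a CPU-resident array this is silent.
warn_if_not_on_cpu(jax.device_put(jnp.ones(3), jax.devices("cpu")[0]))
```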

> One last thing, @selmanozleyen: did you convert the output to a csr_matrix within the for loop to simulate what we do in sparsify? This might add a lot of overhead.

No, this was just for apply.

selmanozleyen (Collaborator) commented Mar 12, 2024

Btw, I tried this comparison as you suggested. The difference is huge: the variant that moves the solution to the CPU at the beginning takes so long that it hits the 6-minute timeout, while the other finishes in 5 seconds. I think the slowness of computing on the CPU outweighs the cost of copying from GPU to CPU. I can share more details once our clusters are faster; I did this quickly on Colab.

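# Methods of an asv benchmark class; self.res and self.batch_size_sp are set up elsewhere.
def time_sparsify_cpu_from_start(self, *args, **kwargs):
    # Move each solution to the CPU first, then sparsify it there.
    for (t1, t2), solution in self.res.solutions.items():
        solution = solution.to(device='cpu')
        solution = solution.sparsify(mode='min_row', batch_size=self.batch_size_sp)

def time_sparsify_cpu_implicit(self, *args, **kwargs):
    # Keep each solution on the GPU; sparsify then copies results to the CPU batch by batch.
    for (t1, t2), solution in self.res.solutions.items():
        solution = solution.to(device='cuda')
        solution = solution.sparsify(mode='min_row', batch_size=self.batch_size_sp)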

@giovp @MUCDK
Update: it times out in most cases when the solution is moved to the CPU first. The results don't accumulate in GPU RAM anyway, so I think the solution should be on the GPU when applying sparsify, and if it runs OOM, one should reduce the batch_size of sparsify as much as possible. Moving to the CPU first makes this call much costlier; maybe that is why it takes so long.
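The "reduce batch_size on OOM" advice could be automated with a retry loop like the one below. run_with_shrinking_batch is a hypothetical helper, not part of moscot or ott, and it assumes an out-of-memory error can be recognized by "RESOURCE_EXHAUSTED" in its message; verify that against the JAX version in use, or pass a different is_oom predicate.

```python
def run_with_shrinking_batch(fn, batch_size, is_oom=lambda err: "RESOURCE_EXHAUSTED" in str(err)):
    """Hypothetical helper: call fn(batch_size), halving batch_size after each OOM."""
    while True:
        try:
            return fn(batch_size), batch_size
        except Exception as err:  # broad on purpose: the OOM exception type varies by JAX version
            if batch_size == 1 or not is_oom(err):
                raise
            batch_size //= 2

def fake_sparsify(batch_size):
    # Toy stand-in that "runs out of memory" above a batch size of 16.
    if batch_size > 16:
        raise RuntimeError("RESOURCE_EXHAUSTED: out of memory (simulated)")
    return f"ok with batch_size={batch_size}"

print(run_with_shrinking_batch(fake_sparsify, 128))  # ('ok with batch_size=16', 16)
```

Halving keeps the number of retries logarithmic in the starting batch size, and non-OOM errors are re-raised unchanged.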
