Skip to content

Commit

Permalink
so much faster
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Feb 27, 2018
1 parent ef96f7a commit fff675e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 10 deletions.
31 changes: 31 additions & 0 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import time

import torch
from torch_cluster import sparse_grid_cluster

n = 90000000
s = 1 / 64

print('GPU ===================')

t = time.perf_counter()
pos = torch.cuda.FloatTensor(n, 3).uniform_(0, 1)
size = torch.cuda.FloatTensor([s, s, s])
torch.cuda.synchronize()
print('Init:', time.perf_counter() - t)

t_all = time.perf_counter()
sparse_grid_cluster(pos, size)
torch.cuda.synchronize()
t_all = time.perf_counter() - t_all
print('All:', t_all)

print('CPU ===================')

pos = pos.cpu()
size = size.cpu()

t_all = time.perf_counter()
sparse_grid_cluster(pos, size)
t_all = time.perf_counter() - t_all
print('All:', t_all)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from setuptools import setup, find_packages

__version__ = '0.2.3'
__version__ = '0.2.4'
url = 'https://github.com/rusty1s/pytorch_cluster'

install_requires = ['cffi', 'torch-unique']
Expand Down
2 changes: 1 addition & 1 deletion torch_cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .functions.grid import sparse_grid_cluster, dense_grid_cluster

__version__ = '0.2.3'
__version__ = '0.2.4'

__all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__']
17 changes: 9 additions & 8 deletions torch_cluster/functions/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ def _preprocess(position, size, batch=None, start=None):

# Translate to minimal positive positions if no start was passed.
if start is None:
position = position - position.min(dim=-2, keepdim=True)[0]
else:
min = []
for i in range(position.size(-1)):
min.append(position[:, i].min())
position = position - position.new(min)
elif start != 0:
position = position - start
assert position.min() >= 0, (
'Passed origin resulting in unallowed negative positions')

# If given, append batch to position tensor.
if batch is not None:
Expand All @@ -37,10 +38,10 @@ def _preprocess(position, size, batch=None, start=None):


def _minimal_cluster_size(position, size):
max = position.max(dim=0)[0]
while max.dim() > 1:
max = max.max(dim=0)[0]
cluster_size = (max / size).long() + 1
max = []
for i in range(position.size(-1)):
max.append(position[:, i].max())
cluster_size = (size.new(max) / size).long() + 1
return cluster_size


Expand Down

0 comments on commit fff675e

Please sign in to comment.