Skip to content

Commit

Permalink
graclus weight gpu bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 18, 2018
1 parent 7985cdd commit 13dabd4
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions aten/THC/THCNumerics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
template<typename T>
struct THCNumerics {
static inline __host__ __device__ T div(T a, T b) { return a / b; }
static inline __host__ __device__ bool gt(T a, T b) { return a > b; }
static inline __host__ __device__ bool gte(T a, T b) { return a >= b; }
};

#ifdef CUDA_HALF_TENSOR
template<>
struct THCNumerics<half> {
static inline __host__ __device__ half div(half a, half b) { return f2h(h2f(a) / h2f(b)); }
static inline __host__ __device__ bool gt(half a, half b) { return h2f(a) > h2f(b); }
static inline __host__ __device__ bool gte(half a, half b) { return h2f(a) >= h2f(b); }
};
#endif // CUDA_HALF_TENSOR

Expand Down
2 changes: 1 addition & 1 deletion aten/THC/THCPropose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro
tmp = weight[e];
if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found.
// Find maximum weighted red neighbor.
if (color[c] == -2 && THCNumerics<T>::gt(tmp, maxWeight)) {
if (color[c] == -2 && THCNumerics<T>::gte(tmp, maxWeight)) {
matchedValue = c;
maxWeight = tmp;
}
Expand Down
2 changes: 1 addition & 1 deletion aten/THC/THCResponse.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r
tmp = weight[e];
if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found.
// Find maximum weighted blue neighbor, who proposed to i.
if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gt(tmp, maxWeight)) {
if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gte(tmp, maxWeight)) {
matchedValue = c;
maxWeight = tmp;
}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from setuptools import setup, find_packages

__version__ = '1.0.2'
__version__ = '1.0.3'
url = 'https://github.com/rusty1s/pytorch_cluster'

install_requires = ['cffi']
Expand Down
2 changes: 1 addition & 1 deletion torch_cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .graclus import graclus_cluster
from .grid import grid_cluster

__version__ = '1.0.2'
__version__ = '1.0.3'

__all__ = ['graclus_cluster', 'grid_cluster', '__version__']

0 comments on commit 13dabd4

Please sign in to comment.