From 13dabd40ded3aca827daa85560ba18a5701e9679 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 18 Apr 2018 15:13:56 +0200 Subject: [PATCH] graclus weight gpu bugfix --- aten/THC/THCNumerics.cuh | 4 ++-- aten/THC/THCPropose.cuh | 2 +- aten/THC/THCResponse.cuh | 2 +- setup.py | 2 +- torch_cluster/__init__.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aten/THC/THCNumerics.cuh b/aten/THC/THCNumerics.cuh index 00a212ae..b2639bc0 100644 --- a/aten/THC/THCNumerics.cuh +++ b/aten/THC/THCNumerics.cuh @@ -16,14 +16,14 @@ template struct THCNumerics { static inline __host__ __device__ T div(T a, T b) { return a / b; } - static inline __host__ __device__ bool gt(T a, T b) { return a > b; } + static inline __host__ __device__ bool gte(T a, T b) { return a >= b; } }; #ifdef CUDA_HALF_TENSOR template<> struct THCNumerics { static inline __host__ __device__ half div(half a, half b) { return f2h(h2f(a) / h2f(b)); } - static inline __host__ __device__ bool gt(half a, half b) { return h2f(a) > h2f(b); } + static inline __host__ __device__ bool gte(half a, half b) { return h2f(a) >= h2f(b); } }; #endif // CUDA_HALF_TENSOR diff --git a/aten/THC/THCPropose.cuh b/aten/THC/THCPropose.cuh index 073a7696..116e3d11 100644 --- a/aten/THC/THCPropose.cuh +++ b/aten/THC/THCPropose.cuh @@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro tmp = weight[e]; if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found. // Find maximum weighted red neighbor. - if (color[c] == -2 && THCNumerics::gt(tmp, maxWeight)) { + if (color[c] == -2 && THCNumerics::gte(tmp, maxWeight)) { matchedValue = c; maxWeight = tmp; } diff --git a/aten/THC/THCResponse.cuh b/aten/THC/THCResponse.cuh index 7a5f9753..e639f154 100644 --- a/aten/THC/THCResponse.cuh +++ b/aten/THC/THCResponse.cuh @@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r tmp = weight[e]; if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found. // Find maximum weighted blue neighbor, who proposed to i. - if (color[c] == -1 && prop[c] == i && THCNumerics::gt(tmp, maxWeight)) { + if (color[c] == -1 && prop[c] == i && THCNumerics::gte(tmp, maxWeight)) { matchedValue = c; maxWeight = tmp; } diff --git a/setup.py b/setup.py index 9c954945..ae149aa1 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages -__version__ = '1.0.2' +__version__ = '1.0.3' url = 'https://github.com/rusty1s/pytorch_cluster' install_requires = ['cffi'] diff --git a/torch_cluster/__init__.py b/torch_cluster/__init__.py index f3205a0e..85de5cb7 100644 --- a/torch_cluster/__init__.py +++ b/torch_cluster/__init__.py @@ -1,6 +1,6 @@ from .graclus import graclus_cluster from .grid import grid_cluster -__version__ = '1.0.2' +__version__ = '1.0.3' __all__ = ['graclus_cluster', 'grid_cluster', '__version__']