Skip to content

Commit

Permalink
Fix build error on rocm4.5/rocm5 on ubuntu18.04
Browse files Browse the repository at this point in the history
* Declare the new RCCL APIs in paddle/fluid
* Fix the build error "free(): invalid pointer" by picking the fix from:
skarupke/flat_hash_map#26
* Exclude flat_hash_map.h from the code-style check

Signed-off-by: jiajuku <[email protected]>
  • Loading branch information
onepick committed Jul 3, 2023
1 parent e5d0861 commit 37379bf
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ repos:
exclude: |
(?x)^(
paddle/cinn/.+|
test/cpp/cinn/.+
test/cpp/cinn/.+|
paddle/utils/flat_hash_map.h+
)$
# For CMake files
- repo: local
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/platform/dynload/rccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP)
#endif

// Version-gated wrapper expansions. Each group is only expanded when the
// NCCL/RCCL headers report a version code at or above the threshold.
// NOTE(review): DEFINE_WRAP presumably emits the definitions matching the
// declarations made in rccl.h under the same guards — confirm.
#if NCCL_VERSION_CODE >= 2304
RCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP)
#endif

// Routines available from version code 2703 onwards.
#if NCCL_VERSION_CODE >= 2703
RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP)
#endif

// Routines available from version code 21100 onwards.
#if NCCL_VERSION_CODE >= 21100
RCCL_RAND_ROUTINE_EACH_AFTER_21100(DEFINE_WRAP)
#endif

} // namespace dynload
} // namespace platform
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/fluid/platform/dynload/rccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,25 @@ RCCL_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
RCCL_RAND_ROUTINE_EACH_AFTER_2212(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

// Version-gated routine lists. Each RCCL_RAND_ROUTINE_EACH_AFTER_<code> macro
// applies `__macro` to the routines in its group and is only defined (and
// expanded with PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) when
// NCCL_VERSION_CODE is at least <code>.

// Group for version code >= 2304: ncclGetVersion.
#if NCCL_VERSION_CODE >= 2304
#define RCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion);
RCCL_RAND_ROUTINE_EACH_AFTER_2304(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

// Group for version code >= 2703: ncclSend and ncclRecv.
#if NCCL_VERSION_CODE >= 2703
#define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \
__macro(ncclSend); \
__macro(ncclRecv);
RCCL_RAND_ROUTINE_EACH_AFTER_2703(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

// Group for version code >= 21100: ncclRedOpCreatePreMulSum and
// ncclRedOpDestroy.
// NOTE(review): 21100 has five digits while the thresholds above have four;
// presumably the version-code encoding changed between releases — confirm
// against the NCCL/RCCL header that defines NCCL_VERSION_CODE.
#if NCCL_VERSION_CODE >= 21100
#define RCCL_RAND_ROUTINE_EACH_AFTER_21100(__macro) \
__macro(ncclRedOpCreatePreMulSum); \
__macro(ncclRedOpDestroy);
RCCL_RAND_ROUTINE_EACH_AFTER_21100(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

} // namespace dynload
} // namespace platform
} // namespace paddle
26 changes: 15 additions & 11 deletions paddle/utils/flat_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,6 @@ struct sherwood_v3_entry {
sherwood_v3_entry(int8_t distance_from_desired)
: distance_from_desired(distance_from_desired) {}
~sherwood_v3_entry() {}
static sherwood_v3_entry *empty_default_table() {
static sherwood_v3_entry result[min_lookups] = {
{}, {}, {}, {special_end_value}};
return result;
}

bool has_value() const { return distance_from_desired >= 0; }
bool is_empty() const { return distance_from_desired < 0; }
Expand Down Expand Up @@ -664,13 +659,24 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
bool empty() const { return num_elements == 0; }

private:
EntryPointer entries = Entry::empty_default_table();
EntryPointer entries = empty_default_table();
size_t num_slots_minus_one = 0;
typename HashPolicySelector<ArgumentHash>::type hash_policy;
int8_t max_lookups = detailv3::min_lookups - 1;
float _max_load_factor = 0.5f;
size_t num_elements = 0;

// Builds the placeholder table used while the container holds no elements.
//
// Allocates detailv3::min_lookups entries through this table's allocator.
// Every slot except the last gets distance_from_desired = -1 (the value that
// is_empty() reports as empty); the last slot is tagged with
// Entry::special_end_value. Only allocate() is called here, so no Entry
// constructor runs on the returned storage.
//
// Returns a pointer to the first entry of the freshly allocated block.
EntryPointer empty_default_table() {
  const ptrdiff_t sentinel_offset =
      static_cast<ptrdiff_t>(detailv3::min_lookups - 1);
  EntryPointer table = AllocatorTraits::allocate(*this, detailv3::min_lookups);
  EntryPointer sentinel = table + sentinel_offset;

  // Mark every slot ahead of the sentinel as empty.
  EntryPointer cursor = table;
  while (cursor != sentinel) {
    cursor->distance_from_desired = -1;
    ++cursor;
  }

  sentinel->distance_from_desired = Entry::special_end_value;
  return table;
}

static int8_t compute_max_lookups(size_t num_buckets) {
int8_t desired = detailv3::log2(num_buckets);
return (std::max)(detailv3::min_lookups, desired);
Expand Down Expand Up @@ -743,15 +749,13 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
void deallocate_data(EntryPointer begin,
size_t num_slots_minus_one,
int8_t max_lookups) {
if (begin != Entry::empty_default_table()) {
AllocatorTraits::deallocate(
*this, begin, num_slots_minus_one + max_lookups + 1);
}
AllocatorTraits::deallocate(
*this, begin, num_slots_minus_one + max_lookups + 1);
}

void reset_to_empty_state() {
deallocate_data(entries, num_slots_minus_one, max_lookups);
entries = Entry::empty_default_table();
entries = empty_default_table();
num_slots_minus_one = 0;
hash_policy.reset();
max_lookups = detailv3::min_lookups - 1;
Expand Down

0 comments on commit 37379bf

Please sign in to comment.