Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build error on rocm4.5/rocm5 on ubuntu18.04 #55076

Merged
merged 2 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ repos:
exclude: |
(?x)^(
paddle/cinn/.+|
test/cpp/cinn/.+
test/cpp/cinn/.+|
paddle/utils/flat_hash_map\.h
onepick marked this conversation as resolved.
Show resolved Hide resolved
)$
# For CMake files
- repo: local
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/platform/dynload/rccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP)
#endif

// Expand the version-gated RCCL routine lists with DEFINE_WRAP. Each guard
// below uses the same condition as the guard in rccl.h under which the
// matching RCCL_RAND_ROUTINE_EACH_AFTER_* list is #defined, so a list is only
// expanded when its macro exists.
// NOTE(review): DEFINE_WRAP is defined outside this hunk; presumably it emits
// the definition of each dynamic-load wrapper -- confirm.

// ncclGetVersion
#if NCCL_VERSION_CODE >= 2304
RCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP)
#endif

// ncclSend, ncclRecv
#if NCCL_VERSION_CODE >= 2703
RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP)
#endif

// ncclRedOpCreatePreMulSum, ncclRedOpDestroy
#if NCCL_VERSION_CODE >= 21100
RCCL_RAND_ROUTINE_EACH_AFTER_21100(DEFINE_WRAP)
#endif

} // namespace dynload
} // namespace platform
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/fluid/platform/dynload/rccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,25 @@ RCCL_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
RCCL_RAND_ROUTINE_EACH_AFTER_2212(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

// Version-gated routine lists. Each list macro applies __macro to every
// routine name it contains and already terminates each entry with ';', which
// is why the expansions below are written without a trailing semicolon.
// NOTE(review): the thresholds presumably encode library versions (e.g. 2304
// for 2.3.4, 21100 for 2.11.0) -- confirm against the NCCL_VERSION_CODE
// definition in the RCCL/NCCL header.

// Routines declared only when NCCL_VERSION_CODE >= 2304: ncclGetVersion.
#if NCCL_VERSION_CODE >= 2304
#define RCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion);
RCCL_RAND_ROUTINE_EACH_AFTER_2304(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

// Routines declared only when NCCL_VERSION_CODE >= 2703: ncclSend, ncclRecv.
#if NCCL_VERSION_CODE >= 2703
#define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \
__macro(ncclSend); \
__macro(ncclRecv);
RCCL_RAND_ROUTINE_EACH_AFTER_2703(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

// Routines declared only when NCCL_VERSION_CODE >= 21100:
// ncclRedOpCreatePreMulSum, ncclRedOpDestroy.
#if NCCL_VERSION_CODE >= 21100
#define RCCL_RAND_ROUTINE_EACH_AFTER_21100(__macro) \
__macro(ncclRedOpCreatePreMulSum); \
__macro(ncclRedOpDestroy);
RCCL_RAND_ROUTINE_EACH_AFTER_21100(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
#endif

} // namespace dynload
} // namespace platform
} // namespace paddle
26 changes: 15 additions & 11 deletions paddle/utils/flat_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,6 @@ struct sherwood_v3_entry {
// Constructs an entry whose distance_from_desired member is initialized to
// the given value. The constructor is not marked explicit, so an int8_t can
// convert implicitly to an entry.
sherwood_v3_entry(int8_t distance_from_desired)
: distance_from_desired(distance_from_desired) {}
// User-provided destructor with an empty body.
// NOTE(review): presumably needed because the struct stores its value in a
// member whose destructor cannot be generated implicitly; the member
// declarations are outside this hunk -- confirm before replacing with
// "= default".
~sherwood_v3_entry() {}
static sherwood_v3_entry *empty_default_table() {
static sherwood_v3_entry result[min_lookups] = {
{}, {}, {}, {special_end_value}};
return result;
}

// True exactly when distance_from_desired is non-negative.
bool has_value() const {
  return 0 <= distance_from_desired;
}
// True exactly when distance_from_desired is negative.
bool is_empty() const {
  return 0 > distance_from_desired;
}
Expand Down Expand Up @@ -664,13 +659,24 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
// Reports whether the table currently stores no elements, i.e. whether
// num_elements is zero.
bool empty() const {
  return 0 == num_elements;
}

private:
EntryPointer entries = Entry::empty_default_table();
EntryPointer entries = empty_default_table();
size_t num_slots_minus_one = 0;
typename HashPolicySelector<ArgumentHash>::type hash_policy;
int8_t max_lookups = detailv3::min_lookups - 1;
float _max_load_factor = 0.5f;
size_t num_elements = 0;

// Builds the entry array that `entries` points to while the table is in its
// empty state (it is the default initializer of `entries` and is used again
// by reset_to_empty_state).
//
// The array holds detailv3::min_lookups entries obtained from this table's
// allocator. Every entry except the last gets distance_from_desired = -1;
// the last one gets Entry::special_end_value.
//
// Returns a pointer to the first entry of the freshly allocated array.
// NOTE(review): the entries are written to directly after allocate() without
// a construct() call -- confirm this is intended for Entry.
EntryPointer empty_default_table() {
  EntryPointer table = AllocatorTraits::allocate(*this, detailv3::min_lookups);
  EntryPointer sentinel =
      table + static_cast<ptrdiff_t>(detailv3::min_lookups - 1);
  EntryPointer cursor = table;
  while (cursor != sentinel) {
    cursor->distance_from_desired = -1;
    ++cursor;
  }
  sentinel->distance_from_desired = Entry::special_end_value;
  return table;
}

static int8_t compute_max_lookups(size_t num_buckets) {
int8_t desired = detailv3::log2(num_buckets);
return (std::max)(detailv3::min_lookups, desired);
Expand Down Expand Up @@ -743,15 +749,13 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
void deallocate_data(EntryPointer begin,
size_t num_slots_minus_one,
int8_t max_lookups) {
if (begin != Entry::empty_default_table()) {
AllocatorTraits::deallocate(
*this, begin, num_slots_minus_one + max_lookups + 1);
}
AllocatorTraits::deallocate(
*this, begin, num_slots_minus_one + max_lookups + 1);
}

void reset_to_empty_state() {
deallocate_data(entries, num_slots_minus_one, max_lookups);
entries = Entry::empty_default_table();
entries = empty_default_table();
num_slots_minus_one = 0;
hash_policy.reset();
max_lookups = detailv3::min_lookups - 1;
Expand Down