diff --git a/cpp/include/nvtext/bpe_tokenize.hpp b/cpp/include/nvtext/bpe_tokenize.hpp index b93d93b07c6..c67f4bd8b1c 100644 --- a/cpp/include/nvtext/bpe_tokenize.hpp +++ b/cpp/include/nvtext/bpe_tokenize.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,19 +61,6 @@ struct bpe_merge_pairs { rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); ~bpe_merge_pairs(); - - /** - * @brief Returns the number of merge pairs in the table. - * - * @return The number of merge pairs in the table - */ - cudf::size_type get_size(); - /** - * @brief Returns the number of unique merge pairs in the table. - * - * @return The number of unique merge pairs in the table - */ - std::size_t get_map_size(); }; /** diff --git a/cpp/src/text/subword/bpe_tokenizer.cu b/cpp/src/text/subword/bpe_tokenizer.cu index ac55fe76db1..4c4f5b3a4b1 100644 --- a/cpp/src/text/subword/bpe_tokenizer.cu +++ b/cpp/src/text/subword/bpe_tokenizer.cu @@ -80,10 +80,11 @@ __device__ cudf::string_view get_first_token(cudf::string_view const& d_str) * * @see The byte_pair_encoding_fn::operator() function below for details. */ +template struct byte_pair_encoding_fn { cudf::column_device_view const d_merges; cudf::column_device_view const d_strings; - merge_pairs_map_type::device_view const d_map; + MapRefType const d_map; cudf::size_type* d_sizes; // output size of encoded string string_hasher_type const hasher; cudf::size_type* d_byte_indices; @@ -136,17 +137,13 @@ struct byte_pair_encoding_fn { } /** - * @brief Compute the hash over the input strings. + * @brief Look up the pair of strings in the d_map/d_merges * - * The input strings are combined with a space to produce hash for matching - * a merge pair within the `d_map`. - * - * @param lhs First string. - * @param rhs Second string. - * @return The hash value to match with `d_map`. + * @param lhs Left half of the string + * @param rhs Right half of the string + * @return Position of merge pair within d_map */ - __device__ cudf::hash_value_type compute_hash(cudf::string_view const& lhs, - cudf::string_view const& rhs) + __device__ auto get_merge_pair(cudf::string_view const& lhs, cudf::string_view const& rhs) { __shared__ char shmem[48 * 1024]; // max for Pascal auto const total_size = lhs.size_bytes() + rhs.size_bytes() + 1; @@ -154,8 +151,8 @@ struct byte_pair_encoding_fn { // Edge case check. // Empirically found only two merge pair strings that were greater than 70 bytes - // and they both looked like ignorable errors. Double check this analysis with Vibhu. - if (thread_memory_size < total_size) { return 0; } + // and they both looked like ignorable errors. + if (thread_memory_size < total_size) { return d_map.end(); } // build the target string in shared memory char* ptr = &shmem[threadIdx.x * thread_memory_size]; @@ -165,8 +162,8 @@ struct byte_pair_encoding_fn { memcpy(ptr + lhs.size_bytes(), " ", 1); memcpy(ptr + lhs.size_bytes() + 1, rhs.data(), rhs.size_bytes()); - auto const d_hash_str = cudf::string_view(ptr, total_size); - return hasher(d_hash_str); // return the hash for the temp string + auto const d_str = cudf::string_view(ptr, total_size); + return d_map.find(d_str); } /** @@ -233,11 +230,10 @@ struct byte_pair_encoding_fn { auto const rhs = next_substr(itr, end, d_str); if (rhs.empty()) break; // no more adjacent pairs - auto const hash = compute_hash(lhs, rhs); - auto const map_itr = d_map.find(hash, thrust::identity{}); + auto const map_itr = get_merge_pair(lhs, rhs); if (map_itr != d_map.end()) { // found a match; record the rank (and other min_ vars) - auto const rank = static_cast(map_itr->second); + auto const rank = map_itr->second; if (rank < min_rank) { min_rank = rank; min_itr = itr; @@ -354,12 +350,12 @@ std::unique_ptr byte_pair_encoding( bpe_merge_pairs::bpe_merge_pairs_impl const& merge_pairs, rmm::cuda_stream_view stream) { - CUDF_EXPECTS(!merge_pairs.get_merge_pairs().is_empty(), "Merge pairs table must not be empty"); + auto const d_merges = merge_pairs.get_merge_pairs(); + CUDF_EXPECTS(d_merges.size() > 0, "Merge pairs table must not be empty"); // build working vector to hold index values per byte rmm::device_uvector d_byte_indices(input.chars().size(), stream); - auto const d_merges = cudf::column_device_view::create(merge_pairs.get_merge_pairs(), stream); auto const d_strings = cudf::column_device_view::create(input.parent(), stream); auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, @@ -369,12 +365,9 @@ std::unique_ptr byte_pair_encoding( rmm::mr::get_current_device_resource()); auto d_offsets = offsets->mutable_view().data(); - byte_pair_encoding_fn fn{*d_merges, - *d_strings, - merge_pairs.get_merge_pairs_map(), - d_offsets, - string_hasher_type{}, - d_byte_indices.data()}; + auto map_ref = merge_pairs.get_merge_pairs_ref(); + byte_pair_encoding_fn fn{ + d_merges, *d_strings, map_ref, d_offsets, string_hasher_type{}, d_byte_indices.data()}; thrust::for_each_n( rmm::exec_policy(stream), thrust::make_counting_iterator(0), input.size(), fn); diff --git a/cpp/src/text/subword/bpe_tokenizer.cuh b/cpp/src/text/subword/bpe_tokenizer.cuh index 0697a9961c7..83aa22aaae9 100644 --- a/cpp/src/text/subword/bpe_tokenizer.cuh +++ b/cpp/src/text/subword/bpe_tokenizer.cuh @@ -21,7 +21,9 @@ #include #include +#include #include +#include #include #include @@ -30,30 +32,84 @@ #include #include +#include namespace nvtext { namespace detail { +using hash_value_type = uint32_t; +using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32; + +/** + * @brief Hasher function used for building and using the cuco static-map + * + * This takes advantage of heterogeneous lookup feature in cuco static-map which + * allows inserting with one type (index) and looking up with a different type (string). + */ +struct bpe_hasher { + cudf::column_device_view const d_strings; + string_hasher_type hasher{}; + // used by insert + __device__ hash_value_type operator()(cudf::size_type index) const + { + return hasher(d_strings.element(index)); + } + // used by find + __device__ hash_value_type operator()(cudf::string_view const& s) const { return hasher(s); } +}; + +/** + * @brief Equal function used for building and using the cuco static-map + * + * This takes advantage of heterogeneous lookup feature in cuco static-map which + * allows inserting with one type (index) and looking up with a different type (string). + */ +struct bpe_equal { + cudf::column_device_view const d_strings; + // used by insert + __device__ bool operator()(cudf::size_type lhs, cudf::size_type rhs) const noexcept + { + return d_strings.element(lhs) == d_strings.element(rhs); + } + // used by find + __device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept + { + return d_strings.element(lhs) == rhs; + } +}; + using hash_table_allocator_type = rmm::mr::stream_allocator_adaptor>; -using merge_pairs_map_type = cuco::static_map; +using probe_scheme = cuco::experimental::linear_probing<1, bpe_hasher>; -using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32; +using merge_pairs_map_type = cuco::experimental::static_map, + cuda::thread_scope_device, + bpe_equal, + probe_scheme, + hash_table_allocator_type>; } // namespace detail +// since column_device_view::create returns is a little more than +// std::unique_ptr this helper simplifies the return type in a more maintainable +// way +using col_device_view = std::invoke_result_t; + struct bpe_merge_pairs::bpe_merge_pairs_impl { std::unique_ptr const merge_pairs; + col_device_view const d_merge_pairs; std::unique_ptr merge_pairs_map; bpe_merge_pairs_impl(std::unique_ptr&& merge_pairs, + col_device_view&& d_merge_pairs, std::unique_ptr&& merge_pairs_map); - auto get_merge_pairs() const { return merge_pairs->view(); } - auto get_merge_pairs_map() const { return merge_pairs_map->get_device_view(); } + auto const get_merge_pairs() const { return *d_merge_pairs; } + auto get_merge_pairs_ref() const { return merge_pairs_map->ref(cuco::experimental::op::find); } }; } // namespace nvtext diff --git a/cpp/src/text/subword/load_merges_file.cu b/cpp/src/text/subword/load_merges_file.cu index b39413af98f..1f1b90b3f49 100644 --- a/cpp/src/text/subword/load_merges_file.cu +++ b/cpp/src/text/subword/load_merges_file.cu @@ -36,23 +36,8 @@ namespace nvtext { namespace detail { - namespace { -struct make_pair_function { - /** - * @brief Hash the merge pair entry - */ - __device__ cuco::pair operator()(cudf::size_type idx) - { - auto const result = _hasher(d_strings.element(idx)); - return cuco::make_pair(result, idx); - } - - string_hasher_type const _hasher; - cudf::column_device_view const d_strings; -}; - /** * @brief Loads a text file of merge-pairs into a strings column. * @@ -101,26 +86,23 @@ std::unique_ptr load_file_to_column(std::string const& filename_me } std::unique_ptr initialize_merge_pairs_map( - cudf::strings_column_view const& input, rmm::cuda_stream_view stream) + cudf::column_device_view const& input, rmm::cuda_stream_view stream) { // Ensure capacity is at least (size/0.7) as documented here: // https://github.com/NVIDIA/cuCollections/blob/6ec8b6dcdeceea07ab4456d32461a05c18864411/include/cuco/static_map.cuh#L179-L182 auto merge_pairs_map = std::make_unique( static_cast(input.size() * 2), // capacity is 2x; - cuco::empty_key{std::numeric_limits::max()}, + cuco::empty_key{-1}, cuco::empty_value{-1}, // empty value is not used + bpe_equal{input}, + probe_scheme{bpe_hasher{input}}, hash_table_allocator_type{default_allocator{}, stream}, stream.value()); - auto d_strings = cudf::column_device_view::create(input.parent(), stream); - make_pair_function pair_func{string_hasher_type{}, *d_strings}; - auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func); + auto iter = cudf::detail::make_counting_transform_iterator( + 0, [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); }); - merge_pairs_map->insert(iter, - iter + input.size(), - thrust::identity{}, - thrust::equal_to{}, - stream.value()); + merge_pairs_map->insert_async(iter, iter + input.size(), stream.value()); return merge_pairs_map; } @@ -128,9 +110,10 @@ std::unique_ptr initialize_merge_pairs_map( std::unique_ptr create_bpe_merge_pairs_impl( std::unique_ptr&& input, rmm::cuda_stream_view stream) { - auto merge_pairs = initialize_merge_pairs_map(cudf::strings_column_view(input->view()), stream); - return std::make_unique(std::move(input), - std::move(merge_pairs)); + auto d_input = cudf::column_device_view::create(input->view(), stream); + auto merge_pairs = initialize_merge_pairs_map(*d_input, stream); + return std::make_unique( + std::move(input), std::move(d_input), std::move(merge_pairs)); } std::unique_ptr create_bpe_merge_pairs_impl( @@ -163,8 +146,12 @@ std::unique_ptr load_merge_pairs_file(std::string const& filena bpe_merge_pairs::bpe_merge_pairs_impl::bpe_merge_pairs_impl( std::unique_ptr&& merge_pairs, + std::unique_ptr>&& + d_merge_pairs, std::unique_ptr&& merge_pairs_map) - : merge_pairs(std::move(merge_pairs)), merge_pairs_map(std::move(merge_pairs_map)) + : merge_pairs(std::move(merge_pairs)), + d_merge_pairs(std::move(d_merge_pairs)), + merge_pairs_map(std::move(merge_pairs_map)) { } @@ -184,7 +171,4 @@ bpe_merge_pairs::bpe_merge_pairs(cudf::strings_column_view const& input, bpe_merge_pairs::~bpe_merge_pairs() = default; -cudf::size_type bpe_merge_pairs::get_size() { return impl->merge_pairs->size(); } -std::size_t bpe_merge_pairs::get_map_size() { return impl->merge_pairs_map->get_size(); } - } // namespace nvtext