From 82e548353790430f6c7959152735ac165110238a Mon Sep 17 00:00:00 2001 From: David Wendt Date: Mon, 9 Oct 2023 10:23:06 -0400 Subject: [PATCH] Expose stream parameter in public strings replace APIs --- cpp/include/cudf/strings/replace.hpp | 42 +++++++----- cpp/include/cudf/strings/replace_re.hpp | 28 +++++--- cpp/src/strings/replace/backref_re.cu | 3 +- cpp/src/strings/replace/multi.cu | 3 +- cpp/src/strings/replace/multi_re.cu | 3 +- cpp/src/strings/replace/replace.cu | 8 ++- cpp/src/strings/replace/replace_re.cu | 4 +- cpp/tests/CMakeLists.txt | 4 +- cpp/tests/streams/strings/replace_test.cpp | 80 ++++++++++++++++++++++ 9 files changed, 136 insertions(+), 39 deletions(-) create mode 100644 cpp/tests/streams/strings/replace_test.cpp diff --git a/cpp/include/cudf/strings/replace.hpp b/cpp/include/cudf/strings/replace.hpp index 22818f7542e..2476a41e886 100644 --- a/cpp/include/cudf/strings/replace.hpp +++ b/cpp/include/cudf/strings/replace.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,19 +54,21 @@ namespace strings { * * @throw cudf::logic_error if target is an empty string. * - * @param strings Strings column for this operation. - * @param target String to search for within each string. - * @param repl Replacement string if target is found. + * @param input Strings column for this operation + * @param target String to search for within each string + * @param repl Replacement string if target is found * @param maxrepl Maximum times to replace if target appears multiple times in the input string. * Default of -1 specifies replace all occurrences of target in each string. - * @param mr Device memory resource used to allocate the returned column's device memory. - * @return New strings column. + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New strings column */ std::unique_ptr replace( - strings_column_view const& strings, + strings_column_view const& input, string_scalar const& target, string_scalar const& repl, - int32_t maxrepl = -1, + cudf::size_type maxrepl = -1, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -92,21 +94,23 @@ std::unique_ptr replace( * * @throw cudf::logic_error if start is greater than stop. * - * @param strings Strings column for this operation. + * @param input Strings column for this operation. * @param repl Replacement string for specified positions found. * Default is empty string. * @param start Start position where repl will be added. * Default is 0, first character position. * @param stop End position (exclusive) to use for replacement. * Default of -1 specifies the end of each string. - * @param mr Device memory resource used to allocate the returned column's device memory. - * @return New strings column. + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New strings column */ std::unique_ptr replace_slice( - strings_column_view const& strings, + strings_column_view const& input, string_scalar const& repl = string_scalar(""), size_type start = 0, size_type stop = -1, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -141,16 +145,18 @@ std::unique_ptr replace_slice( * if repls is a single string. * @throw cudf::logic_error if targets or repls contain null entries. * - * @param strings Strings column for this operation. - * @param targets Strings to search for in each string. - * @param repls Corresponding replacement strings for target strings. - * @param mr Device memory resource used to allocate the returned column's device memory. - * @return New strings column. + * @param input Strings column for this operation + * @param targets Strings to search for in each string + * @param repls Corresponding replacement strings for target strings + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New strings column */ std::unique_ptr replace( - strings_column_view const& strings, + strings_column_view const& input, strings_column_view const& targets, strings_column_view const& repls, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of doxygen group diff --git a/cpp/include/cudf/strings/replace_re.hpp b/cpp/include/cudf/strings/replace_re.hpp index bc6659835c3..77db2882253 100644 --- a/cpp/include/cudf/strings/replace_re.hpp +++ b/cpp/include/cudf/strings/replace_re.hpp @@ -43,20 +43,22 @@ struct regex_program; * * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. * - * @param strings Strings instance for this operation + * @param input Strings instance for this operation * @param prog Regex program instance * @param replacement The string used to replace the matched sequence in each string. * Default is an empty string. * @param max_replace_count The maximum number of times to replace the matched pattern * within each string. Default replaces every substring that is matched. + * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned column's device memory * @return New strings column */ std::unique_ptr replace_re( - strings_column_view const& strings, + strings_column_view const& input, regex_program const& prog, string_scalar const& replacement = string_scalar(""), std::optional max_replace_count = std::nullopt, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -67,18 +69,20 @@ std::unique_ptr replace_re( * * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. * - * @param strings Strings instance for this operation. - * @param patterns The regular expression patterns to search within each string. - * @param replacements The strings used for replacement. - * @param flags Regex flags for interpreting special characters in the patterns. - * @param mr Device memory resource used to allocate the returned column's device memory. - * @return New strings column. + * @param input Strings instance for this operation + * @param patterns The regular expression patterns to search within each string + * @param replacements The strings used for replacement + * @param flags Regex flags for interpreting special characters in the patterns + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New strings column */ std::unique_ptr replace_re( - strings_column_view const& strings, + strings_column_view const& input, std::vector const& patterns, strings_column_view const& replacements, regex_flags const flags = regex_flags::DEFAULT, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -92,16 +96,18 @@ std::unique_ptr replace_re( * @throw cudf::logic_error if capture index values in `replacement` are not in range 0-99, and also * if the index exceeds the group count specified in the pattern * - * @param strings Strings instance for this operation + * @param input Strings instance for this operation * @param prog Regex program instance * @param replacement The replacement template for creating the output string + * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned column's device memory * @return New strings column */ std::unique_ptr replace_with_backrefs( - strings_column_view const& strings, + strings_column_view const& input, regex_program const& prog, std::string_view replacement, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); } // namespace strings diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 31e06aac72b..74f38cbcc20 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -148,10 +148,11 @@ std::unique_ptr replace_with_backrefs(strings_column_view const& input, std::unique_ptr replace_with_backrefs(strings_column_view const& strings, regex_program const& prog, std::string_view replacement, + rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_with_backrefs(strings, prog, replacement, cudf::get_default_stream(), mr); + return detail::replace_with_backrefs(strings, prog, replacement, stream, mr); } } // namespace strings diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index 92ace4e7bc7..ee47932100a 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -490,10 +490,11 @@ std::unique_ptr replace(strings_column_view const& input, std::unique_ptr replace(strings_column_view const& strings, strings_column_view const& targets, strings_column_view const& repls, + rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace(strings, targets, repls, cudf::get_default_stream(), mr); + return detail::replace(strings, targets, repls, stream, mr); } } // namespace strings diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 867b443c036..3375cb7a789 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -206,10 +206,11 @@ std::unique_ptr replace_re(strings_column_view const& strings, std::vector const& patterns, strings_column_view const& replacements, regex_flags const flags, + rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_re(strings, patterns, replacements, flags, cudf::get_default_stream(), mr); + return detail::replace_re(strings, patterns, replacements, flags, stream, mr); } } // namespace strings diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index acc1502f4d6..a6a14f27dec 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -751,21 +751,23 @@ std::unique_ptr replace_nulls(strings_column_view const& strings, std::unique_ptr replace(strings_column_view const& strings, string_scalar const& target, string_scalar const& repl, - int32_t maxrepl, + cudf::size_type maxrepl, + rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace(strings, target, repl, maxrepl, cudf::get_default_stream(), mr); + return detail::replace(strings, target, repl, maxrepl, stream, mr); } std::unique_ptr replace_slice(strings_column_view const& strings, string_scalar const& repl, size_type start, size_type stop, + rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_slice(strings, repl, start, stop, cudf::get_default_stream(), mr); + return detail::replace_slice(strings, repl, start, stop, stream, mr); } } // namespace strings diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index 81ddb937be5..502d5f1a52e 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -134,11 +134,11 @@ std::unique_ptr replace_re(strings_column_view const& strings, regex_program const& prog, string_scalar const& replacement, std::optional max_replace_count, + rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_re( - strings, prog, replacement, max_replace_count, cudf::get_default_stream(), mr); + return detail::replace_re(strings, prog, replacement, max_replace_count, stream, mr); } } // namespace strings diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index ac13c121530..b14c3cd1300 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -633,8 +633,8 @@ ConfigureTest(STREAM_REPLACE_TEST streams/replace_test.cpp STREAM_MODE testing) ConfigureTest(STREAM_SEARCH_TEST streams/search_test.cpp STREAM_MODE testing) ConfigureTest(STREAM_DICTIONARY_TEST streams/dictionary_test.cpp STREAM_MODE testing) ConfigureTest( - STREAM_STRINGS_TEST streams/strings/case_test.cpp streams/strings/find_test.cpp STREAM_MODE - testing + STREAM_STRINGS_TEST streams/strings/case_test.cpp streams/strings/find_test.cpp + streams/strings/replace_test.cpp STREAM_MODE testing ) ConfigureTest(STREAM_SORTING_TEST streams/sorting_test.cpp STREAM_MODE testing) ConfigureTest(STREAM_TEXT_TEST streams/text/ngrams_test.cpp STREAM_MODE testing) diff --git a/cpp/tests/streams/strings/replace_test.cpp b/cpp/tests/streams/strings/replace_test.cpp new file mode 100644 index 00000000000..82fe46d152b --- /dev/null +++ b/cpp/tests/streams/strings/replace_test.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include + +#include + +class StringsReplaceTest : public cudf::test::BaseFixture {}; + +TEST_F(StringsReplaceTest, Replace) +{ + auto input = cudf::test::strings_column_wrapper({"Héllo", "thesé", "tést strings", ""}); + auto view = cudf::strings_column_view(input); + + auto const target = cudf::string_scalar("é", true, cudf::test::get_default_stream()); + auto const repl = cudf::string_scalar(" ", true, cudf::test::get_default_stream()); + cudf::strings::replace(view, target, repl, -1, cudf::test::get_default_stream()); + cudf::strings::replace(view, view, view, cudf::test::get_default_stream()); + cudf::strings::replace_slice(view, repl, 1, 2, cudf::test::get_default_stream()); + + auto const pattern = std::string("[a-z]"); + auto const prog = cudf::strings::regex_program::create(pattern); + cudf::strings::replace_re(view, *prog, repl, 1, cudf::test::get_default_stream()); + + cudf::test::strings_column_wrapper repls({"1", "a", " "}); + cudf::strings::replace_re(view, + {pattern, pattern, pattern}, + cudf::strings_column_view(repls), + cudf::strings::regex_flags::DEFAULT, + cudf::test::get_default_stream()); +} + +TEST_F(StringsReplaceTest, ReplaceRegex) +{ + auto input = cudf::test::strings_column_wrapper({"Héllo", "thesé", "tést strings", ""}); + auto view = cudf::strings_column_view(input); + + auto const repl = cudf::string_scalar(" ", true, cudf::test::get_default_stream()); + auto const pattern = std::string("[a-z]"); + auto const prog = cudf::strings::regex_program::create(pattern); + cudf::strings::replace_re(view, *prog, repl, 1, cudf::test::get_default_stream()); + + cudf::test::strings_column_wrapper repls({"1", "a", " "}); + cudf::strings::replace_re(view, + {pattern, pattern, pattern}, + cudf::strings_column_view(repls), + cudf::strings::regex_flags::DEFAULT, + cudf::test::get_default_stream()); +} + +TEST_F(StringsReplaceTest, ReplaceRegexBackref) +{ + auto input = cudf::test::strings_column_wrapper({"Héllo thesé", "tést strings"}); + auto view = cudf::strings_column_view(input); + + auto const repl_template = std::string("\\2-\\1"); + auto const pattern = std::string("(\\w) (\\w)"); + auto const prog = cudf::strings::regex_program::create(pattern); + cudf::strings::replace_with_backrefs( + view, *prog, repl_template, cudf::test::get_default_stream()); +}