Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast codepointOffset #451

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
138 changes: 138 additions & 0 deletions cbits/codepoint_offset.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@

#include <string.h>
#include <stdint.h>
#include <sys/types.h>
#ifdef __x86_64__
#include <emmintrin.h>
#include <xmmintrin.h>
#endif
#include <stdbool.h>
#include <sys/cdefs.h>

// The following is from FreeBSD's memmem.c
// https://github.com/freebsd/freebsd-src/blob/9921563f43a924d21c7bf43db4a34e724577db95/lib/libc/string/memmem.c

/*-
* SPDX-License-Identifier: MIT
*
* Copyright (c) 2005-2014 Rich Felker, et al, 2022 Alex Mason.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

static uint8_t *
twobyte_memmem(const uint8_t *h, size_t hlen, const uint8_t *n)
{
uint16_t nw = n[0] << 8 | n[1], hw = h[0] << 8 | h[1];
for (h += 2, hlen -= 2; hlen; hlen--, hw = hw << 8 | *h++)
if (hw == nw)
return (uint8_t *)h - 2;
return hw == nw ? (uint8_t *)h - 2 : 0;
}

static uint8_t *
threebyte_memmem(const uint8_t *h, size_t hlen, const uint8_t *n)
{
uint32_t nw = (uint32_t)n[0] << 24 | n[1] << 16 | n[2] << 8;
uint32_t hw = (uint32_t)h[0] << 24 | h[1] << 16 | h[2] << 8;
for (h += 3, hlen -= 3; hlen; hlen--, hw = (hw | *h++) << 8)
if (hw == nw)
return (uint8_t *)h - 3;
return hw == nw ? (uint8_t *)h - 3 : 0;
}

static uint8_t *
fourbyte_memmem(const uint8_t *h, size_t hlen, const uint8_t *n)
{
uint32_t nw = (uint32_t)n[0] << 24 | n[1] << 16 | n[2] << 8 | n[3];
uint32_t hw = (uint32_t)h[0] << 24 | h[1] << 16 | h[2] << 8 | h[3];
for (h += 4, hlen -= 4; hlen; hlen--, hw = hw << 8 | *h++)
if (hw == nw)
return (uint8_t *)h - 4;
return hw == nw ? (uint8_t *)h - 4 : 0;
}

static int _hs_codepoint_to_utf8(uint8_t asUtf8[4], uint32_t codepoint)
{

if (codepoint < 0x80)
{
asUtf8[0] = codepoint;
return 1;
}
else if (codepoint < 0x0800)
{
asUtf8[0] = (uint8_t)(((codepoint >> 6) & 0x1F) | 0xC0);
asUtf8[1] = (uint8_t)(((codepoint >> 0) & 0x3F) | 0x80);
return 2;
}
else if (codepoint < 0x10000)
{
asUtf8[0] = (uint8_t)(((codepoint >> 12) & 0x0F) | 0xE0);
asUtf8[1] = (uint8_t)(((codepoint >> 6) & 0x3F) | 0x80);
asUtf8[2] = (uint8_t)(((codepoint >> 0) & 0x3F) | 0x80);
return 3;
}
else
{
asUtf8[0] = (uint8_t)(((codepoint >> 18) & 0x07) | 0xF0);
asUtf8[1] = (uint8_t)(((codepoint >> 12) & 0x3F) | 0x80);
asUtf8[2] = (uint8_t)(((codepoint >> 6) & 0x3F) | 0x80);
asUtf8[3] = (uint8_t)(((codepoint >> 0) & 0x3F) | 0x80);
return 4;
}
}

size_t _hs_offset_of_codepoint(const uint8_t *haystack0, const size_t hoffset, const size_t hlen0, const size_t needle)
{
const uint8_t *haystack = haystack0 + hoffset;
uint8_t *res = NULL;
uint8_t asUtf8[4] = {0};
const int codepointLen = _hs_codepoint_to_utf8(asUtf8, needle);

// Skip to first location that could contain the character.
uint8_t *haystackFirst = (uint8_t *)memchr(haystack, asUtf8[0], hlen0);

if (haystackFirst)
{
const size_t hlen = hlen0 - (haystackFirst - haystack);

switch (codepointLen)
{
case 1:
res = haystackFirst;
break;
case 2:
res = twobyte_memmem(haystackFirst, hlen, asUtf8);
break;
case 3:
res = threebyte_memmem(haystackFirst, hlen, asUtf8);
break;
case 4:
res = fourbyte_memmem(haystackFirst, hlen, asUtf8);
break;
default:
res = NULL;
break;
}
}

return res == NULL ? -1 : (size_t)((uint8_t *)res - haystack);
}
77 changes: 69 additions & 8 deletions src/Data/Text.hs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ module Data.Text
, stripEnd
, splitAt
, breakOn
, breakOnChar
, breakOnEnd
, break
, span
Expand All @@ -157,6 +158,7 @@ module Data.Text
-- $split
, splitOn
, split
, splitOnChar
, chunksOf

-- ** Breaking into lines and words
Expand Down Expand Up @@ -204,6 +206,7 @@ module Data.Text
, unpackCStringAscii#

, measureOff
, codepointOffset
) where

import Prelude (Char, Bool(..), Int, Maybe(..), String,
Expand Down Expand Up @@ -262,6 +265,7 @@ import System.IO.Unsafe (unsafePerformIO)
-- $setup
-- >>> :set -package transformers
-- >>> import Control.Monad.Trans.State
-- >>> import Data.Char (isUpper)
-- >>> import Data.Text
-- >>> import qualified Data.Text as T
-- >>> :seti -XOverloadedStrings
Expand Down Expand Up @@ -411,7 +415,7 @@ instance Data Text where
instance TH.Lift Text where
#if MIN_VERSION_template_haskell(2,16,0)
lift txt = do
let (ptr, len) = unsafePerformIO $ asForeignPtr txt
let (ptr, len) = unsafePerformIO $ asForeignPtr txt
let lenInt = P.fromIntegral len
TH.appE (TH.appE (TH.varE 'unpackCStringLen#) (TH.litE . TH.bytesPrimL $ TH.mkBytes ptr 0 lenInt)) (TH.lift lenInt)
#else
Expand Down Expand Up @@ -1300,6 +1304,19 @@ measureOff !n (Text (A.ByteArray arr) off len) = if len == 0 then 0 else
foreign import ccall unsafe "_hs_text_measure_off" c_measure_off
:: ByteArray# -> CSize -> CSize -> CSize -> IO CSsize

-- | /O(n)/ Finds the byte offset of the first occurrence of @c@ in the @Text@, or
-- '-1' if if can't be found.
codepointOffset :: Text -> Char -> Int
codepointOffset !(Text (A.ByteArray arr) off len) c = if len == 0 then -1 else
cSsizeToInt $ unsafeDupablePerformIO $
c_hs_offset_of_codepoint arr (intToCSize off) (intToCSize len) (intToCSize $ ord c)

-- | The input buffer (arr :: ByteArray#, off :: CSize, len :: CSize)
-- must specify a valid UTF-8 sequence, and the character must be less
-- than 0x10FFFF, these conditions are not checked.
foreign import ccall unsafe "_hs_offset_of_codepoint" c_hs_offset_of_codepoint
::ByteArray# -> CSize -> CSize -> CSize -> IO CSsize

-- | /O(n)/ 'takeEnd' @n@ @t@ returns the suffix remaining after
-- taking @n@ characters from the end of @t@.
--
Expand Down Expand Up @@ -1584,27 +1601,30 @@ splitOn :: HasCallStack
-> [Text]
splitOn pat@(Text _ _ l) src@(Text arr off len)
| l <= 0 = emptyError "splitOn"
| isSingleton pat = split (== unsafeHead pat) src
| isSingleton pat = splitOnChar (unsafeHead pat) src
| otherwise = go 0 (indices pat src)
where
go !s (x:xs) = text arr (s+off) (x-s) : go (x+l) xs
go s _ = [text arr (s+off) (len-s)]
{-# INLINE [1] splitOn #-}

{-# RULES
"TEXT splitOn/singleton -> split/==" [~1] forall c t.
splitOn (singleton c) t = split (==c) t
"TEXT splitOn/singleton -> splitOnChar" [~1] forall c t.
splitOn (singleton c) t = splitOnChar c t
#-}


-- | /O(n)/ Splits a 'Text' into components delimited by separators,
-- where the predicate returns True for a separator element. The
-- resulting components do not contain the separators. Two adjacent
-- separators result in an empty component in the output. eg.
-- separators result in an empty component in the output. To split
-- on a specific character, use @splitOnChar@.
-- eg.
--
-- >>> split (=='a') "aabbaca"
-- ["","","bb","c",""]
-- >>> split isUpper "theQuickBrownFox"
-- ["the","uick","rown","ox"]
--
-- >>> split (=='a') ""
-- >>> split isUpper ""
-- [""]
split :: (Char -> Bool) -> Text -> [Text]
split _ t@(Text _off _arr 0) = [t]
Expand All @@ -1614,6 +1634,32 @@ split p t = loop t
where (# l, s' #) = span_ (not . p) s
{-# INLINE split #-}


{- TODO Fix:
Rule "TEXT split/eq1 -> splitOnChar/==" may never fire
because rule "Class op ==" for ‘==’ might fire first
Probable fix: add phase [n] or [~n] to the competing rulecompile(-Winline-rule-shadowing)
-}
{-# RULES
"TEXT split/eq1 -> splitOnChar/==" [~2] forall c t.
split (== c) t = splitOnChar c t
"TEXT split/eq1 -> splitOnChar/==" [~2] forall c t.
split (c ==) t = splitOnChar c t
#-}


-- | /O(n)/ Splits a 'Text' into components delimited by the given @Char@.
-- The behaviour is the same as @split@ but should be faster than @split (== c)@
--
-- >>> split (=='a') "aabbaca"
-- ["","","bb","c",""]
splitOnChar :: Char -> Text -> [Text]
splitOnChar _ t@(Text _off _arr 0) = [t]
splitOnChar c t = loop t
where loop s | null s' = [l]
| otherwise = l : loop (unsafeTail s')
where ( l, s' ) = breakOnChar c s

-- | /O(n)/ Splits a 'Text' into components of length @k@. The last
-- element may be shorter than the other chunks, depending on the
-- length of the input. Examples:
Expand Down Expand Up @@ -1737,6 +1783,8 @@ filter p = go
-- is the prefix of @haystack@ before @needle@ is matched. The second
-- is the remainder of @haystack@, starting with the match.
--
-- To break on a specific character, use @breakOnChar@
--
-- Examples:
--
-- >>> breakOn "::" "a::b::c"
Expand Down Expand Up @@ -1764,6 +1812,19 @@ breakOn pat src@(Text arr off len)
(x:_) -> (text arr off x, text arr (off+x) (len-x))
{-# INLINE breakOn #-}

-- | /O(n)/ Equivalent to @breakOn (== c)@ but should be faster.
--
-- >>> breakOnChar '/' "foo/bar/"
-- ("foo","/bar/")
--
-- >>> breakOnChar '/' "foobar"
-- ("foobar","")
breakOnChar :: Char -> Text -> (Text, Text)
breakOnChar c src@(Text arr off len) = case codepointOffset src c of
-1 -> (src, empty)
n -> (text arr off n, text arr (off+n) (len-n) )


-- | /O(n+m)/ Similar to 'breakOn', but searches from the end of the
-- string.
--
Expand Down
30 changes: 29 additions & 1 deletion tests/Tests/Properties/Text.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import qualified Data.Text.Internal.Fusion.Common as S
import qualified Data.Text.Internal.Lazy.Fusion as SL
import qualified Data.Text.Internal.Lazy.Search as S (indices)
import qualified Data.Text.Internal.Search as T (indices)
import qualified Data.Text.Internal as TI (Text(..))
import qualified Data.Text.Lazy as TL
import qualified Tests.SlowFunctions as Slow

Expand Down Expand Up @@ -267,6 +268,28 @@ tl_indices_char_drop n c pref suff = map fromIntegral (S.indices s t) === Slow.i
s = TL.singleton c
t = TL.drop n $ pref `TL.append` s `TL.append` suff

t_codepointOffset_exists :: T.Text -> Char -> T.Text -> Property
t_codepointOffset_exists tPrefix target tSuffix =
let cleanPrefix@(TI.Text _ _ len) = T.filter (/= target) tPrefix
in T.codepointOffset (T.append cleanPrefix $ T.cons target tSuffix) target === len

t_codepointOffset_missing :: T.Text -> Char -> Bool
t_codepointOffset_missing t target = T.codepointOffset (T.filter (/= target) t) target == -1

t_breakOnChar_exists :: T.Text -> Char -> T.Text -> Bool
t_breakOnChar_exists tPrefix target tSuffix =
let cleanPrefix = T.filter (/= target) tPrefix
(before, after) = T.breakOnChar target (T.append cleanPrefix $ T.cons target tSuffix)
in before == cleanPrefix && after == T.cons target tSuffix

t_breakOnChar_missing :: T.Text -> Char -> Bool
t_breakOnChar_missing t target =
let filtered = T.filter (/= target) t
in T.breakOnChar target filtered == (filtered,T.empty)

t_breakOnChar_is_break_eq_char :: T.Text -> Char -> Bool
t_breakOnChar_is_break_eq_char t c = T.breakOnChar c t == T.break (== c) t

-- Make a stream appear shorter than it really is, to ensure that
-- functions that consume inaccurately sized streams behave
-- themselves.
Expand Down Expand Up @@ -374,7 +397,12 @@ testText =
testProperty "t_find" t_find,
testProperty "tl_find" tl_find,
testProperty "t_partition" t_partition,
testProperty "tl_partition" tl_partition
testProperty "tl_partition" tl_partition,
testProperty "t_codepointOffset_exists" t_codepointOffset_exists,
testProperty "t_codepointOffset_missing" t_codepointOffset_missing,
testProperty "t_breakOnChar_exists" t_breakOnChar_exists,
testProperty "t_breakOnChar_missing" t_breakOnChar_missing,
testProperty "t_breakOnChar_is_break_eq_char" t_breakOnChar_is_break_eq_char
],

testGroup "indexing" [
Expand Down
1 change: 1 addition & 0 deletions text.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ library
cbits/measure_off.c
cbits/reverse.c
cbits/utils.c
cbits/codepoint_offset.c
hs-source-dirs: src

if flag(simdutf)
Expand Down