diff --git a/squeal-postgresql/src/Squeal/PostgreSQL/Expression/Array.hs b/squeal-postgresql/src/Squeal/PostgreSQL/Expression/Array.hs index c35e8b74..003d36b0 100644 --- a/squeal-postgresql/src/Squeal/PostgreSQL/Expression/Array.hs +++ b/squeal-postgresql/src/Squeal/PostgreSQL/Expression/Array.hs @@ -37,6 +37,15 @@ module Squeal.PostgreSQL.Expression.Array , unnest , arrAny , arrAll + , arrayAppend + , arrayPrepend + , arrayCat + , arrayPosition + , arrayPositionBegins + , arrayPositions + , arrayRemoveNull + , arrayReplace + , trimArray ) where import Data.String @@ -47,6 +56,7 @@ import qualified Generics.SOP as SOP import Squeal.PostgreSQL.Expression import Squeal.PostgreSQL.Expression.Logic +import Squeal.PostgreSQL.Expression.Null import Squeal.PostgreSQL.Expression.Type import Squeal.PostgreSQL.Query.From.Set import Squeal.PostgreSQL.Render @@ -240,3 +250,42 @@ arrAny -> Expression grp lat with db params from (null ('PGvararray ty2)) -- ^ array -> Condition grp lat with db params from arrAny x (?) xs = x ? (UnsafeExpression $ "ANY" <+> parenthesized (renderSQL xs)) + +arrayAppend :: '[null ('PGvararray ty), ty] ---> null ('PGvararray ty) +arrayAppend = unsafeFunctionN "array_append" + +arrayPrepend :: '[ty, null ('PGvararray ty)] ---> null ('PGvararray ty) +arrayPrepend = unsafeFunctionN "array_prepend" + +arrayCat + :: '[null ('PGvararray ty), null ('PGvararray ty)] + ---> null ('PGvararray ty) +arrayCat = unsafeFunctionN "array_cat" + +arrayPosition :: '[null ('PGvararray ty), ty] ---> 'Null 'PGint8 +arrayPosition = unsafeFunctionN "array_position" + +arrayPositionBegins + :: '[null ('PGvararray ty), ty, null 'PGint8] ---> 'Null 'PGint8 +arrayPositionBegins = unsafeFunctionN "array_position" + +arrayPositions + :: '[null ('PGvararray ty), ty] + ---> null ('PGvararray ('NotNull 'PGint8)) +arrayPositions = unsafeFunctionN "array_positions" + +arrayRemove :: '[null ('PGvararray ty), ty] ---> null ('PGvararray ty) +arrayRemove = unsafeFunctionN "array_remove" + +arrayRemoveNull :: null ('PGvararray ('Null ty)) --> null ('PGvararray ('NotNull ty)) +arrayRemoveNull arr = UnsafeExpression (renderSQL (arrayRemove (arr *: null_))) + +arrayReplace + :: '[null ('PGvararray ty), ty, ty] + ---> null ('PGvararray ty) +arrayReplace = unsafeFunctionN "array_replace" + +trimArray + :: '[null ('PGvararray ty), 'NotNull 'PGint8] + ---> null ('PGvararray ty) +trimArray = unsafeFunctionN "trim_array" diff --git a/squeal-postgresql/src/Squeal/PostgreSQL/Session.hs b/squeal-postgresql/src/Squeal/PostgreSQL/Session.hs index 25c4a85e..71ad1cce 100644 --- a/squeal-postgresql/src/Squeal/PostgreSQL/Session.hs +++ b/squeal-postgresql/src/Squeal/PostgreSQL/Session.hs @@ -44,7 +44,6 @@ import Control.Monad (MonadPlus(..)) import Control.Monad.Base (MonadBase(..)) import Control.Monad.Fix (MonadFix(..)) import Control.Monad.Catch -import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Morph import Control.Monad.Reader import Control.Monad.Trans.Control (MonadBaseControl(..), MonadTransControl(..)) diff --git a/squeal-postgresql/src/Squeal/PostgreSQL/Session/Decode.hs b/squeal-postgresql/src/Squeal/PostgreSQL/Session/Decode.hs index b52ddb06..e4f55580 100644 --- a/squeal-postgresql/src/Squeal/PostgreSQL/Session/Decode.hs +++ b/squeal-postgresql/src/Squeal/PostgreSQL/Session/Decode.hs @@ -43,6 +43,7 @@ module Squeal.PostgreSQL.Session.Decode , genericProductRow , appendRows , consRow + , ArrayField (..) -- * Decoding Classes , FromValue (..) , FromField (..) @@ -533,6 +534,42 @@ instance {-# OVERLAPPABLE #-} IsLabel fld (MaybeT (DecodeRow row) y) fromLabel = MaybeT . decodeRow $ \(_ SOP.:* bs) -> runDecodeRow (runMaybeT (fromLabel @fld)) bs +{- | Utility for decoding array fields in a `DecodeRow`, +accessed via overloaded labels. +-} +newtype ArrayField row y = ArrayField + { runArrayField + :: StateT Strict.ByteString (Except Strict.Text) y + -> DecodeRow row [y] + } +instance {-# OVERLAPPING #-} + ( KnownSymbol fld + , PG y ~ ty + , arr ~ 'NotNull ('PGvararray ('NotNull ty)) + ) => IsLabel fld (ArrayField (fld ::: arr ': row) y) where + fromLabel = ArrayField $ \yval -> + decodeRow $ \(SOP.K bytesMaybe SOP.:* _) -> do + let + flderr = mconcat + [ "field name: " + , "\"", fromString (symbolVal (SOP.Proxy @fld)), "\"; " + ] + yarr + = devalue + . array + . dimensionArray replicateM + . valueArray + . revalue + $ yval + case bytesMaybe of + Nothing -> Left (flderr <> "encountered unexpected NULL") + Just bytes -> runExcept (evalStateT yarr bytes) +instance {-# OVERLAPPABLE #-} IsLabel fld (ArrayField row y) + => IsLabel fld (ArrayField (field ': row) y) where + fromLabel = ArrayField $ \yval -> + decodeRow $ \(_ SOP.:* bytess) -> + runDecodeRow (runArrayField (fromLabel @fld) yval) bytess + -- | A `GenericRow` constraint to ensure that a Haskell type -- is a record type, -- has a `RowPG`,