Skip to content

Commit

Permalink
Merge pull request #358 from morphismtech/array-field-explicit-decoding
Browse files Browse the repository at this point in the history
Array utilities
  • Loading branch information
echatav authored Dec 23, 2024
2 parents 73a4022 + 6eed275 commit d0381e8
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
49 changes: 49 additions & 0 deletions squeal-postgresql/src/Squeal/PostgreSQL/Expression/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ module Squeal.PostgreSQL.Expression.Array
, unnest
, arrAny
, arrAll
, arrayAppend
, arrayPrepend
, arrayCat
, arrayPosition
, arrayPositionBegins
, arrayPositions
, arrayRemoveNull
, arrayReplace
, trimArray
) where

import Data.String
Expand All @@ -47,6 +56,7 @@ import qualified Generics.SOP as SOP

import Squeal.PostgreSQL.Expression
import Squeal.PostgreSQL.Expression.Logic
import Squeal.PostgreSQL.Expression.Null
import Squeal.PostgreSQL.Expression.Type
import Squeal.PostgreSQL.Query.From.Set
import Squeal.PostgreSQL.Render
Expand Down Expand Up @@ -240,3 +250,42 @@ arrAny
-> Expression grp lat with db params from (null ('PGvararray ty2)) -- ^ array
-> Condition grp lat with db params from
arrAny x (?) xs = x ? (UnsafeExpression $ "ANY" <+> parenthesized (renderSQL xs))

arrayAppend :: '[null ('PGvararray ty), ty] ---> null ('PGvararray ty)
arrayAppend = unsafeFunctionN "array_append"

arrayPrepend :: '[ty, null ('PGvararray ty)] ---> null ('PGvararray ty)
arrayPrepend = unsafeFunctionN "array_prepend"

arrayCat
:: '[null ('PGvararray ty), null ('PGvararray ty)]
---> null ('PGvararray ty)
arrayCat = unsafeFunctionN "array_cat"

arrayPosition :: '[null ('PGvararray ty), ty] ---> 'Null 'PGint8
arrayPosition = unsafeFunctionN "array_position"

arrayPositionBegins
:: '[null ('PGvararray ty), ty, null 'PGint8] ---> 'Null 'PGint8
arrayPositionBegins = unsafeFunctionN "array_position"

arrayPositions
:: '[null ('PGvararray ty), ty]
---> null ('PGvararray ('NotNull 'PGint8))
arrayPositions = unsafeFunctionN "array_positions"

arrayRemove :: '[null ('PGvararray ty), ty] ---> null ('PGvararray ty)
arrayRemove = unsafeFunctionN "array_remove"

arrayRemoveNull :: null ('PGvararray ('Null ty)) --> null ('PGvararray ('NotNull ty))
arrayRemoveNull arr = UnsafeExpression (renderSQL (arrayRemove (arr *: null_)))

arrayReplace
:: '[null ('PGvararray ty), ty, ty]
---> null ('PGvararray ty)
arrayReplace = unsafeFunctionN "array_replace"

trimArray
:: '[null ('PGvararray ty), 'NotNull 'PGint8]
---> null ('PGvararray ty)
trimArray = unsafeFunctionN "trim_array"
37 changes: 37 additions & 0 deletions squeal-postgresql/src/Squeal/PostgreSQL/Session/Decode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ module Squeal.PostgreSQL.Session.Decode
, genericProductRow
, appendRows
, consRow
, ArrayField (..)
-- * Decoding Classes
, FromValue (..)
, FromField (..)
Expand Down Expand Up @@ -540,6 +541,42 @@ instance {-# OVERLAPPABLE #-} IsLabel fld (MaybeT (DecodeRow row) y)
fromLabel = MaybeT . decodeRow $ \(_ SOP.:* bs) ->
runDecodeRow (runMaybeT (fromLabel @fld)) bs

{- | Utility for decoding array fields in a `DecodeRow`,
accessed via overloaded labels.
-}
newtype ArrayField row y = ArrayField
{ runArrayField
:: StateT Strict.ByteString (Except Strict.Text) y
-> DecodeRow row [y]
}
instance {-# OVERLAPPING #-}
( KnownSymbol fld
, PG y ~ ty
, arr ~ 'NotNull ('PGvararray ('NotNull ty))
) => IsLabel fld (ArrayField (fld ::: arr ': row) y) where
fromLabel = ArrayField $ \yval ->
decodeRow $ \(SOP.K bytesMaybe SOP.:* _) -> do
let
flderr = mconcat
[ "field name: "
, "\"", fromString (symbolVal (SOP.Proxy @fld)), "\"; "
]
yarr
= devalue
. array
. dimensionArray replicateM
. valueArray
. revalue
$ yval
case bytesMaybe of
Nothing -> Left (flderr <> "encountered unexpected NULL")
Just bytes -> runExcept (evalStateT yarr bytes)
instance {-# OVERLAPPABLE #-} IsLabel fld (ArrayField row y)
=> IsLabel fld (ArrayField (field ': row) y) where
fromLabel = ArrayField $ \yval ->
decodeRow $ \(_ SOP.:* bytess) ->
runDecodeRow (runArrayField (fromLabel @fld) yval) bytess

-- | A `GenericRow` constraint to ensure that a Haskell type
-- is a record type,
-- has a `RowPG`,
Expand Down

0 comments on commit d0381e8

Please sign in to comment.