diff --git a/src/PostgREST/ApiRequest.hs b/src/PostgREST/ApiRequest.hs index d537030709..6b27ca34e5 100644 --- a/src/PostgREST/ApiRequest.hs +++ b/src/PostgREST/ApiRequest.hs @@ -37,7 +37,7 @@ import Data.Aeson.Types (emptyArray, emptyObject) import Data.List (lookup) import Data.Ranged.Ranges (emptyRange, rangeIntersection, rangeIsEmpty) -import Data.Tree (flatten) +import Data.Tree (Tree (..)) import Network.HTTP.Types.Header (RequestHeaders, hCookie) import Network.HTTP.Types.URI (parseSimpleQuery) import Network.Wai (Request (..)) @@ -117,7 +117,7 @@ data ApiRequest = ApiRequest { , iPayload :: Maybe Payload -- ^ Data sent by client and used for mutation actions , iPreferences :: Preferences.Preferences -- ^ Prefer header values , iQueryParams :: QueryParams.QueryParams - , iColumns :: S.Set FieldName -- ^ parsed colums from &columns parameter and payload + , iColumns :: S.Set (Tree FieldName) -- ^ parsed colums from &columns parameter and payload , iHeaders :: [(ByteString, ByteString)] -- ^ HTTP request headers , iCookies :: [(ByteString, ByteString)] -- ^ Request Cookies , iPath :: ByteString -- ^ Raw request path @@ -237,13 +237,13 @@ getRanges method QueryParams{qsOrder,qsRanges} hdrs isInvalidRange = topLevelRange == emptyRange && not (hasLimitZero limitRange) topLevelRange = fromMaybe allRange $ HM.lookup "limit" ranges -- if no limit is specified, get all the request rows -getPayload :: RequestBody -> MediaType -> QueryParams.QueryParams -> Action -> Either ApiRequestError (Maybe Payload, S.Set FieldName) +getPayload :: RequestBody -> MediaType -> QueryParams.QueryParams -> Action -> Either ApiRequestError (Maybe Payload, S.Set (Tree FieldName)) getPayload reqBody contentMediaType QueryParams{qsColumns} action = do checkedPayload <- if shouldParsePayload then payload else Right Nothing let cols = case (checkedPayload, columns) of - (Just ProcessedJSON{payKeys}, _) -> payKeys - (Just ProcessedUrlEncoded{payKeys}, _) -> payKeys - (Just RawJSON{}, Just cls) -> S.fromList $ foldl (<>) [] (flatten <$> cls) + (Just ProcessedJSON{payKeys}, _) -> S.map (`Node` []) payKeys + (Just ProcessedUrlEncoded{payKeys}, _) -> S.map (`Node` []) payKeys + (Just RawJSON{}, Just cls) -> S.fromList cls _ -> S.empty return (checkedPayload, cols) where diff --git a/src/PostgREST/ApiRequest/QueryParams.hs b/src/PostgREST/ApiRequest/QueryParams.hs index 7ce2c64a97..8d4e431e95 100644 --- a/src/PostgREST/ApiRequest/QueryParams.hs +++ b/src/PostgREST/ApiRequest/QueryParams.hs @@ -863,7 +863,7 @@ pColumnName = lexeme $ do fld <- pFieldName pEnd return fld - where + where pEnd = try (void $ lookAhead (string ")")) <|> try (void $ lookAhead (string ",")) <|> try eof diff --git a/src/PostgREST/Plan.hs b/src/PostgREST/Plan.hs index 43b67a7ae0..06d5530661 100644 --- a/src/PostgREST/Plan.hs +++ b/src/PostgREST/Plan.hs @@ -151,7 +151,7 @@ callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} invMethod = do let paramKeys = case invMethod of InvRead _ -> S.fromList $ fst <$> qsParams' - Inv -> iColumns + Inv -> S.map rootLabel iColumns proc@Function{..} <- mapLeft ApiRequestError $ findProc identifier paramKeys (preferParameters == Just SingleObject) (dbRoutines sCache) iContentMediaType (invMethod == Inv) let relIdentifier = QualifiedIdentifier pdSchema (fromMaybe pdName $ Routine.funcTableName proc) -- done so a set returning function can embed other relations @@ -923,7 +923,7 @@ mutatePlan mutation qi ApiRequest{iPreferences=Preferences{..}, ..} SchemaCache{ combinedLogic = foldr (addFilterToLogicForest . resolveFilter ctx) logic qsFiltersRoot body = payRaw <$> iPayload -- the body is assumed to be json at this stage(ApiRequest validates) applyDefaults = preferMissing == Just ApplyDefaults - typedColumnsOrError = resolveOrError ctx tbl `traverse` S.toList iColumns + typedColumnsOrError = resolveOrError ctx tbl `traverse` S.toList (S.map rootLabel iColumns) resolveOrError :: ResolverContext -> Maybe Table -> FieldName -> Either ApiRequestError CoercibleField resolveOrError _ Nothing _ = Left NotFound