Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: [WIP] mixing offset and limit with Range header #3578

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ test-suite spec
Feature.Query.ErrorSpec
Feature.Query.InsertSpec
Feature.Query.JsonOperatorSpec
Feature.Query.LimitOffsetSpec
Feature.Query.LimitedMutationSpec
Feature.Query.MultipleSchemaSpec
Feature.Query.NullsStripSpec
Expand Down
39 changes: 14 additions & 25 deletions src/PostgREST/ApiRequest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Module : PostgREST.Request.ApiRequest
Description : PostgREST functions to translate HTTP request to a domain type called ApiRequest.
-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
-- TODO: This module shouldn't depend on SchemaCache
module PostgREST.ApiRequest
( ApiRequest(..)
Expand Down Expand Up @@ -35,8 +36,7 @@ import Data.Either.Combinators (mapBoth)
import Control.Arrow ((***))
import Data.Aeson.Types (emptyArray, emptyObject)
import Data.List (lookup)
import Data.Ranged.Ranges (emptyRange, rangeIntersection,
rangeIsEmpty)
import Data.Ranged.Ranges (rangeIsEmpty)
import Network.HTTP.Types.Header (RequestHeaders, hCookie)
import Network.HTTP.Types.URI (parseSimpleQuery)
import Network.Wai (Request (..))
Expand All @@ -50,8 +50,6 @@ import PostgREST.Config (AppConfig (..),
OpenAPIMode (..))
import PostgREST.MediaType (MediaType (..))
import PostgREST.RangeQuery (NonnegRange, allRange,
convertToLimitZeroRange,
hasLimitZero,
rangeRequested)
import PostgREST.SchemaCache (SchemaCache (..))
import PostgREST.SchemaCache.Identifiers (FieldName,
Expand Down Expand Up @@ -111,8 +109,7 @@ data Action
-}
data ApiRequest = ApiRequest {
iAction :: Action -- ^ Action on the resource
, iRange :: HM.HashMap Text NonnegRange -- ^ Requested range of rows within response
, iTopLevelRange :: NonnegRange -- ^ Requested range of rows from the top level
, iRange :: NonnegRange -- ^ Requested range of rows from the selected resource
, iPayload :: Maybe Payload -- ^ Data sent by client and used for mutation actions
, iPreferences :: Preferences.Preferences -- ^ Prefer header values
, iQueryParams :: QueryParams.QueryParams
Expand All @@ -134,12 +131,11 @@ userApiRequest conf req reqBody sCache = do
(schema, negotiatedByProfile) <- getSchema conf hdrs method
act <- getAction resource schema method
qPrms <- first QueryParamError $ QueryParams.parse (actIsInvokeSafe act) $ rawQueryString req
(topLevelRange, ranges) <- getRanges method qPrms hdrs
hRange <- getRange method qPrms hdrs
(payload, columns) <- getPayload reqBody contentMediaType qPrms act
return $ ApiRequest {
iAction = act
, iRange = ranges
, iTopLevelRange = topLevelRange
, iRange = hRange
, iPayload = payload
, iPreferences = Preferences.fromHeaders (configDbTxAllowOverride conf) (dbTimezones sCache) hdrs
, iQueryParams = qPrms
Expand Down Expand Up @@ -217,24 +213,17 @@ getSchema AppConfig{configDbSchemas} hdrs method = do
acceptProfile = T.decodeUtf8 <$> lookupHeader "Accept-Profile"
lookupHeader = flip lookup hdrs

getRanges :: ByteString -> QueryParams -> RequestHeaders -> Either ApiRequestError (NonnegRange, HM.HashMap Text NonnegRange)
getRanges method QueryParams{qsOrder,qsRanges} hdrs
| isInvalidRange = Left $ InvalidRange (if rangeIsEmpty headerRange then LowerGTUpper else NegativeLimit)
| method `elem` ["PATCH", "DELETE"] && not (null qsRanges) && null qsOrder = Left LimitNoOrderError
| method == "PUT" && topLevelRange /= allRange = Left PutLimitNotAllowedError
| otherwise = Right (topLevelRange, ranges)
getRange :: ByteString -> QueryParams -> RequestHeaders -> Either ApiRequestError NonnegRange
getRange method QueryParams{..} hdrs
| rangeIsEmpty headerRange = Left $ InvalidRange LowerGTUpper -- A Range is empty unless its upper boundary is GT its lower boundary
| method `elem` ["PATCH","DELETE"] && not (null qsLimit) && null qsOrder = Left LimitNoOrderError
| method == "PUT" && offsetLimitPresent = Left PutLimitNotAllowedError
| otherwise = Right headerRange
where
-- According to the RFC (https://www.rfc-editor.org/rfc/rfc9110.html#name-range),
-- the Range header must be ignored for all methods other than GET
headerRange = if method == "GET" then rangeRequested hdrs else allRange
limitRange = fromMaybe allRange (HM.lookup "limit" qsRanges)
headerAndLimitRange = rangeIntersection headerRange limitRange
-- Bypass all the ranges and send only the limit zero range (0 <= x <= -1) if
-- limit=0 is present in the query params (not allowed for the Range header)
ranges = HM.insert "limit" (convertToLimitZeroRange limitRange headerAndLimitRange) qsRanges
-- The only emptyRange allowed is the limit zero range
isInvalidRange = topLevelRange == emptyRange && not (hasLimitZero limitRange)
topLevelRange = fromMaybe allRange $ HM.lookup "limit" ranges -- if no limit is specified, get all the request rows
offsetLimitPresent = not (null qsOffset && null qsLimit)

getPayload :: RequestBody -> MediaType -> QueryParams.QueryParams -> Action -> Either ApiRequestError (Maybe Payload, S.Set FieldName)
getPayload reqBody contentMediaType QueryParams{qsColumns} action = do
Expand Down
8 changes: 4 additions & 4 deletions src/PostgREST/ApiRequest/Preferences.hs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ fromHeaders allowTxDbOverride acceptedTzNames headers =
prefMap :: ToHeaderValue a => [a] -> Map.Map ByteString a
prefMap = Map.fromList . fmap (\pref -> (toHeaderValue pref, pref))

prefAppliedHeader :: Preferences -> Maybe HTTP.Header
prefAppliedHeader Preferences {preferResolution, preferRepresentation, preferParameters, preferCount, preferTransaction, preferMissing, preferHandling, preferTimezone, preferMaxAffected } =
prefAppliedHeader :: Bool -> Preferences -> Maybe HTTP.Header
prefAppliedHeader rangeHdPresent Preferences {preferResolution, preferRepresentation, preferParameters, preferCount, preferTransaction, preferMissing, preferHandling, preferTimezone, preferMaxAffected } =
if null prefsVals
then Nothing
else Just (HTTP.hPreferenceApplied, combined)
Expand All @@ -190,7 +190,7 @@ prefAppliedHeader Preferences {preferResolution, preferRepresentation, preferPar
, toHeaderValue <$> preferMissing
, toHeaderValue <$> preferRepresentation
, toHeaderValue <$> preferParameters
, toHeaderValue <$> preferCount
, toHeaderValue <$> (if rangeHdPresent then preferCount else Nothing)
, toHeaderValue <$> preferTransaction
, toHeaderValue <$> preferHandling
, toHeaderValue <$> preferTimezone
Expand Down Expand Up @@ -254,7 +254,7 @@ instance ToHeaderValue PreferCount where

shouldCount :: Maybe PreferCount -> Bool
shouldCount prefCount =
prefCount == Just ExactCount || prefCount == Just EstimatedCount
prefCount `elem` [Just ExactCount, Just PlannedCount, Just EstimatedCount]

-- | Whether to commit or roll back transactions.
data PreferTransaction
Expand Down
129 changes: 68 additions & 61 deletions src/PostgREST/ApiRequest/QueryParams.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
module PostgREST.ApiRequest.QueryParams
( parse
, QueryParams(..)
, pRequestRange
, pTreePath
) where

import qualified Data.ByteString.Char8 as BS
import qualified Data.HashMap.Strict as HM
import qualified Data.List as L
import qualified Data.Set as S
import qualified Data.Text as T
Expand All @@ -22,42 +21,42 @@ import qualified Network.HTTP.Base as HTTP
import qualified Network.HTTP.Types.URI as HTTP
import qualified Text.ParserCombinators.Parsec as P

import Control.Arrow ((***))
import Data.Either.Combinators (mapLeft)
import Data.List (init, last)
import Data.Ranged.Boundaries (Boundary (..))
import Data.Ranged.Ranges (Range (..))
import Data.Tree (Tree (..))
import Text.Parsec.Error (errorMessages,
showErrorMessages)
import Text.ParserCombinators.Parsec (GenParser, ParseError, Parser,
anyChar, between, char, choice,
digit, eof, errorPos, letter,
lookAhead, many1, noneOf,
notFollowedBy, oneOf,
optionMaybe, sepBy, sepBy1,
string, try, (<?>))

import PostgREST.RangeQuery (NonnegRange, allRange,
rangeGeq, rangeLimit,
rangeOffset, restrictRange)
import Control.Arrow ((***))
import Data.Either.Combinators (mapLeft)
import Data.List (init, last)
import Data.Tree (Tree (..))
import PostgREST.ApiRequest.Types (AggregateFunction (..),
EmbedParam (..), EmbedPath,
Field, Filter (..),
FtsOperator (..), Hint,
JoinType (..),
JsonOperand (..),
JsonOperation (..),
JsonPath, ListVal,
LogicOperator (..),
LogicTree (..), OpExpr (..),
OpQuantifier (..),
Operation (..),
OrderDirection (..),
OrderNulls (..),
OrderTerm (..),
QPError (..),
QuantOperator (..),
SelectItem (..),
SimpleOperator (..),
SingleVal, TrileanVal (..))
import PostgREST.SchemaCache.Identifiers (FieldName)

import PostgREST.ApiRequest.Types (AggregateFunction (..),
EmbedParam (..), EmbedPath, Field,
Filter (..), FtsOperator (..),
Hint, JoinType (..),
JsonOperand (..),
JsonOperation (..), JsonPath,
ListVal, LogicOperator (..),
LogicTree (..), OpExpr (..),
OpQuantifier (..), Operation (..),
OrderDirection (..),
OrderNulls (..), OrderTerm (..),
QPError (..), QuantOperator (..),
SelectItem (..),
SimpleOperator (..), SingleVal,
TrileanVal (..))
import Text.Parsec.Error (errorMessages,
showErrorMessages)
import Text.ParserCombinators.Parsec (GenParser, ParseError,
Parser, anyChar, between,
char, choice, digit, eof,
errorPos, letter, lookAhead,
many1, noneOf,
notFollowedBy, oneOf,
optionMaybe, sepBy, sepBy1,
string, try, (<?>))
import Text.Read (read)

import Protolude hiding (Sum, try)

Expand All @@ -67,8 +66,10 @@ data QueryParams =
-- ^ Canonical representation of the query params, sorted alphabetically
, qsParams :: [(Text, Text)]
-- ^ Parameters for RPC calls
, qsRanges :: HM.HashMap Text (Range Integer)
-- ^ Ranges derived from &limit and &offset params
, qsOffset :: [(EmbedPath, Integer)]
-- ^ &offset parameter
, qsLimit :: [(EmbedPath, Integer)]
-- ^ &limit parameter
, qsOrder :: [(EmbedPath, [OrderTerm])]
-- ^ &order parameters for each level
, qsLogic :: [(EmbedPath, LogicTree)]
Expand Down Expand Up @@ -115,6 +116,8 @@ parse :: Bool -> ByteString -> Either QPError QueryParams
parse isRpcRead qs = do
rOrd <- pRequestOrder `traverse` order
rLogic <- pRequestLogicTree `traverse` logic
rOffset <- pRequestOffset `traverse` offset
rLimit <- pRequestLimit `traverse` limit
rCols <- pRequestColumns columns
rSel <- pRequestSelect select
(rFlts, params) <- L.partition hasOp <$> pRequestFilter isRpcRead `traverse` filters
Expand All @@ -125,7 +128,7 @@ parse isRpcRead qs = do
params' = mapMaybe (\case {(_, Filter (fld, _) (NoOpExpr v)) -> Just (fld,v); _ -> Nothing}) params
rFltsRoot' = snd <$> rFltsRoot

return $ QueryParams canonical params' ranges rOrd rLogic rCols rSel rFlts rFltsRoot' rFltsNotRoot rFltsFields rOnConflict
return $ QueryParams canonical params' rOffset rLimit rOrd rLogic rCols rSel rFlts rFltsRoot' rFltsNotRoot rFltsFields rOnConflict
where
hasRootFilter, hasOp :: (EmbedPath, Filter) -> Bool
hasRootFilter ([], _) = True
Expand All @@ -138,9 +141,8 @@ parse isRpcRead qs = do
onConflict = lookupParam "on_conflict"
columns = lookupParam "columns"
order = filter (endingIn ["order"] . fst) nonemptyParams
limits = filter (endingIn ["limit"] . fst) nonemptyParams
-- Replace .offset ending with .limit to be able to match those params later in a map
offsets = first (replaceLast "limit") <$> filter (endingIn ["offset"] . fst) nonemptyParams
offset = filter (endingIn ["offset"] . fst) nonemptyParams
limit = filter (endingIn ["limit"] . fst) nonemptyParams
lookupParam :: Text -> Maybe Text
Copy link
Member

@laurenceisla laurenceisla Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should still allow doing something like

/clients?select=*,projects(*)&projects.limit=1

I think this change breaks that.

lookupParam needle = toS <$> join (L.lookup needle qParams)
nonemptyParams = mapMaybe (\(k, v) -> (k,) <$> v) qParams
Expand All @@ -155,7 +157,7 @@ parse isRpcRead qs = do
. map (join (***) BS.unpack . second (fromMaybe mempty))
$ qString

endingIn:: [Text] -> Text -> Bool
endingIn :: [Text] -> Text -> Bool
endingIn xx key = lastWord `elem` xx
where lastWord = L.last $ T.split (== '.') key

Expand All @@ -164,21 +166,6 @@ parse isRpcRead qs = do
reserved = ["select", "columns", "on_conflict"]
reservedEmbeddable = ["order", "limit", "offset", "and", "or"]

replaceLast x s = T.intercalate "." $ L.init (T.split (=='.') s) <> [x]

ranges :: HM.HashMap Text (Range Integer)
ranges = HM.unionWith f limitParams offsetParams
where
f rl ro = Range (BoundaryBelow o) (BoundaryAbove $ o + l - 1)
where
l = fromMaybe 0 $ rangeLimit rl
o = rangeOffset ro

limitParams =
HM.fromList [(k, restrictRange (readMaybe v) allRange) | (k,v) <- limits]

offsetParams =
HM.fromList [(k, maybe allRange rangeGeq (readMaybe v)) | (k,v) <- offsets]

simpleOperator :: Parser SimpleOperator
simpleOperator =
Expand Down Expand Up @@ -243,11 +230,19 @@ pRequestOrder (k, v) = mapError $ (,) <$> path <*> ord'
path = fst <$> treePath
ord' = P.parse pOrder ("failed to parse order (" ++ toS v ++ ")") $ toS v

pRequestRange :: (Text, NonnegRange) -> Either QPError (EmbedPath, NonnegRange)
pRequestRange (k, v) = mapError $ (,) <$> path <*> pure v
pRequestOffset :: (Text, Text) -> Either QPError (EmbedPath, Integer)
pRequestOffset (k,v) = mapError $ (,) <$> path <*> int
where
treePath = P.parse pTreePath ("failed to parse tree path (" ++ toS k ++ ")") $ toS k
path = fst <$> treePath
int = P.parse pInt ("failed to parse offset parameter (" <> toS v <> ")") $ toS v

pRequestLimit :: (Text, Text) -> Either QPError (EmbedPath, Integer)
pRequestLimit (k,v) = mapError $ (,) <$> path <*> int
where
treePath = P.parse pTreePath ("failed to parse tree path (" ++ toS k ++ ")") $ toS k
path = fst <$> treePath
int = P.parse pInt ("failed to parse limit parameter (" <> toS v <> ")") $ toS v

pRequestLogicTree :: (Text, Text) -> Either QPError (EmbedPath, LogicTree)
pRequestLogicTree (k, v) = mapError $ (,) <$> embedPath <*> logicTree
Expand Down Expand Up @@ -842,6 +837,18 @@ pLogicPath = do
notOp = "not." <> op
return (filter (/= "not") (init path), if "not" `elem` path then notOp else op)

pInt :: Parser Integer
pInt = pPosInt <|> pNegInt
where
pPosInt :: Parser Integer
pPosInt = many1 digit <&> read

pNegInt :: Parser Integer
pNegInt = do
_ <- char '-'
n <- many1 digit
return ((-1) * read n)

pColumns :: Parser [FieldName]
pColumns = pFieldName `sepBy1` lexeme (char ',')

Expand Down
3 changes: 1 addition & 2 deletions src/PostgREST/ApiRequest/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ data RaiseError
| NoDetail
deriving Show
data RangeError
= NegativeLimit
| LowerGTUpper
= LowerGTUpper
| OutOfBounds Text Text
deriving Show

Expand Down
Loading
Loading