Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COPDS-1541 from tokenized to classical pagination #73

Merged
merged 8 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 69 additions & 149 deletions cads_catalogue_api_service/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import datetime
import urllib
from typing import Any, Type

Expand All @@ -28,147 +26,75 @@
import stac_fastapi.types
import stac_fastapi.types.core
import stac_pydantic
from dateutil import parser

from . import config, database, dependencies, exceptions, models, search_utils
from . import (
config,
database,
dependencies,
exceptions,
extensions,
models,
search_utils,
)
from .fastapisessionmaker import FastAPISessionMaker


def decode_base64(encoded: str) -> str:
    """Decode a base64 string back into the text it represents.

    The base64 payload itself is pure ASCII, but the bytes it carries are
    decoded as UTF-8 so that non-ASCII text (for example an accented
    title) can be recovered. ASCII payloads decode exactly as before,
    because ASCII is a subset of UTF-8.

    Raises:
        binascii.Error: if ``encoded`` is not valid base64.
        UnicodeDecodeError: if the decoded bytes are not valid UTF-8.
    """
    raw = base64.b64decode(encoded.encode("ascii"))
    return raw.decode("utf-8")


def encode_base64(decoded: str) -> str:
    """Encode text as a base64 string.

    The text is serialized as UTF-8 before encoding, so non-ASCII
    characters no longer raise ``UnicodeEncodeError``. For pure-ASCII
    input the output is identical to an ASCII-based encoding.
    The returned token only contains base64 (ASCII) characters.
    """
    token = base64.b64encode(decoded.encode("utf-8"))
    return token.decode("ascii")


def encode_cursor(plain_cursor: Any) -> str:
    """Encode a plain cursor value as a base64 token.

    The serialization depends on the type of the value:

    * ``datetime.datetime`` (including subclasses) is converted with
      ``astimezone()`` (no argument, i.e. the local timezone) and written
      in ISO 8601 format;
    * anything else is serialized with ``str()``.
    """
    # isinstance (rather than an exact type comparison) so that datetime
    # subclasses are also serialized as ISO 8601.
    if isinstance(plain_cursor, datetime.datetime):
        text = plain_cursor.astimezone().isoformat()
    else:
        text = str(plain_cursor)
    return encode_base64(text)


def decode_cursor(encoded_cursor: str, sortby: str) -> Any:
    """Turn an encoded cursor back into the value it was built from.

    When ``sortby`` is ``"update"`` the decoded text is parsed into a
    datetime with ``dateutil``'s parser; for every other sorting criterion
    the decoded text is returned unchanged.
    """
    plain_text = decode_base64(encoded_cursor)
    if sortby == "update":
        return parser.parse(plain_text)
    return plain_text


def get_sorting_clause(
    model: type[cads_catalogue.database.Resource], sort: str, inverse: bool = False
) -> dict | tuple:
    """Get the sorting clause for a sort criterion.

    Args:
        model: the resource model whose columns are used for sorting.
        sort: one of ``"update"``, ``"title"`` or ``"id"``; any other value
            falls back to the ``"update"`` clause.
        inverse: when true, flip the direction of the selected clause.
            Optional, so callers can pass either two or three arguments.

    Returns:
        A ``(column, ordering_function)`` pair, where the ordering function
        is ``sqlalchemy.asc`` or ``sqlalchemy.desc``.
    """
    ascending, descending = sqlalchemy.asc, sqlalchemy.desc
    if inverse:
        ascending, descending = descending, ascending

    supported_sorts = {
        "update": (model.resource_update, descending),
        "title": (model.title, ascending),
        "id": (model.resource_uid, ascending),
    }
    return supported_sorts.get(sort) or supported_sorts["update"]


# Comparison dunder names per sorting criterion, as (forward, backward).
# NOTE(review): "title" is sorted ascending by get_sorting_clause but uses the
# same operators as the descending "update" sort — confirm this is intended.
DEFINED_SORT_CRITERIA = {
    "update": ("__le__", "__gt__"),
    "title": ("__le__", "__gt__"),
    "id": ("__ge__", "__lt__"),
}


def get_cursor_compare_criteria(sortby: str, back: bool = False) -> str:
    """Return the name of the comparison method used to filter by cursor.

    The name is looked up in ``DEFINED_SORT_CRITERIA``; sorting criteria
    not listed there use ``"__ge__"`` going forward and ``"__lt__"`` going
    back.
    """
    forward_op, backward_op = DEFINED_SORT_CRITERIA.get(sortby) or (
        "__ge__",
        "__lt__",
    )
    if back:
        return backward_op
    return forward_op


def apply_sorting_and_limit(
    search: sqlalchemy.orm.Query,
    sortby: str,
    page: int,
    limit: int,
    q: str | None = "",
):
    """Apply sortby and limit to the running query.

    Args:
        search: the query to refine.
        sortby: sorting criterion. ``"relevance"`` is honoured only when
            ``q`` is non-empty; otherwise the clause from
            ``get_sorting_clause`` is used.
        page: zero-based page index; ``page * limit`` rows are skipped.
        limit: maximum number of rows returned for the page.
        q: full-text search input used to rank by relevance.

    Returns:
        The query with ordering, offset and limit applied.
    """
    if sortby == "relevance" and q:
        # generate sorting by relevance based on input
        search = search.order_by(search_utils.fulltext_order_by(q))
    else:
        sort_by, sort_order_fn = get_sorting_clause(
            cads_catalogue.database.Resource, sortby
        )
        search = search.order_by(sort_order_fn(sort_by))

    # Classical pagination: skip the previous pages, then cap the page size.
    return search.offset(page * limit).limit(limit)


def get_next_prev_links(
    sortby: str,
    page: int,
    limit: int,
    count: int,
) -> dict[str, Any]:
    """Generate the query parameters for the prev/next pagination links.

    # See https://github.com/radiantearth/stac-api-spec/tree/main/item-search#pagination

    Args:
        sortby: sorting criterion to repeat in the links.
        page: zero-based index of the current page.
        limit: page size.
        count: total number of records matching the search.

    Returns:
        A dict with an optional ``"next"`` and an optional ``"prev"`` entry,
        each holding the ``page``, ``limit`` and ``sortby`` parameters of
        the target page.
    """
    links: dict[str, Any] = {}

    # Next: only when there are records beyond the end of the current page
    if (page + 1) * limit < count:
        links["next"] = {"page": page + 1, "limit": limit, "sortby": sortby}

    # Prev: every page but the first one has a predecessor
    if page > 0:
        links["prev"] = {"page": page - 1, "limit": limit, "sortby": sortby}

    return links


Expand Down Expand Up @@ -355,10 +281,10 @@ def lookup_id(
# avoid loading datasets from other portals, to block URL manipulation/pollution
search = search.filter(record.portal.in_(portals))
row = search.one()
except sqlalchemy.orm.exc.NoResultFound:
except sqlalchemy.orm.exc.NoResultFound as exc:
raise stac_fastapi.types.errors.NotFoundError(
f"{record.__name__} {id} not found"
)
) from exc
return row


Expand All @@ -368,7 +294,7 @@ def get_active_message(
filter_types=["warning", "critical"],
) -> models.Message | None:
"""Return the latest active message for a dataset."""
messages = (
message = (
session.query(cads_catalogue.database.Message)
.join(cads_catalogue.database.Message.resources)
.where(
Expand All @@ -377,17 +303,17 @@ def get_active_message(
cads_catalogue.database.Message.severity.in_(filter_types),
)
.order_by(cads_catalogue.database.Message.date.desc())
.all()
.first()
)
if messages:
if message:
return models.Message(
id=messages[0].message_uid,
id=message.message_uid,
date=None,
summary=messages[0].summary,
url=messages[0].url,
severity=messages[0].severity,
content=messages[0].content,
live=messages[0].live,
summary=message.summary,
url=message.url,
severity=message.severity,
content=message.content,
live=message.live,
)
return None

Expand All @@ -398,6 +324,8 @@ def collection_serializer(
request: fastapi.Request,
preview: bool = False,
schema_org: bool = False,
with_message: bool = True,
with_keywords: bool = True,
) -> stac_fastapi.types.stac.Collection:
"""Transform database model to stac collection."""
collection_links = generate_collection_links(
Expand All @@ -408,7 +336,7 @@ def collection_serializer(
model=db_model, base_url=config.settings.document_storage_url
)

active_message = get_active_message(db_model, session)
active_message = get_active_message(db_model, session) if with_message else None

additional_properties = {
**({"assets": assets} if assets else {}),
Expand Down Expand Up @@ -449,7 +377,11 @@ def collection_serializer(
title=db_model.title,
description=db_model.abstract,
# FIXME: this is triggering a long list of subqueries
keywords=[keyword.keyword_name for keyword in db_model.keywords],
keywords=(
[keyword.keyword_name for keyword in db_model.keywords]
if with_keywords
else []
),
# https://github.com/radiantearth/stac-spec/blob/master/collection-spec/collection-spec.md#license
# note that this small check, even if correct, is triggering a lot of subrequests
license=(
Expand Down Expand Up @@ -537,10 +469,9 @@ def all_datasets(
request: fastapi.Request,
q: str | None = None,
kw: list[str] | None = [],
sortby: str = "relevance",
cursor: str | None = None,
sortby: extensions.CatalogueSortCriterion = extensions.CatalogueSortCriterion.update_desc,
page: int = 0,
limit: int = 999,
back: bool = False,
route_name="Get Collections",
search_stats: bool = False,
) -> models.CADSCollections:
Expand All @@ -550,27 +481,21 @@ def all_datasets(
)

route_ref = str(request.url_for(route_name))

base_url = str(request.base_url)

with self.reader.context_session() as session:
search = session.query(self.collection_table).options(
*database.deferred_columns
)
search = search_utils.apply_filters(session, search, q, kw, portals=portals)
count = search.count()
search, sort_by = apply_sorting(
search=search, sortby=sortby, cursor=cursor, limit=limit, inverse=back
search = apply_sorting_and_limit(
search=search, q=q, sortby=sortby, page=page, limit=limit
)
collections = search.all()

# Filter function always returns an item more than the limit to know if there is a next/prev page
# But response is build on effective page size
if len(collections) <= limit:
results = collections
else:
results = collections[:-1]

if len(results) == 0:
if len(collections) == 0 and route_name != "Get Collections":
# For canonical STAC requests to /collections, we don't want to raise a 404
raise stac_fastapi.types.errors.NotFoundError(
"Search does not match any dataset"
)
Expand All @@ -579,7 +504,7 @@ def all_datasets(
collection_serializer(
collection, session=session, request=request, preview=True
)
for collection in (results if not back else reversed(results))
for collection in collections
]

links = [
Expand All @@ -601,41 +526,36 @@ def all_datasets(
]

next_prev_links = get_next_prev_links(
collections=collections,
sort_by=sort_by,
cursor=cursor,
sortby=sortby.value,
page=page,
limit=limit,
back=back,
count=count,
)
if next_prev_links.get("next"):

if next_prev_links.get("prev"):
qs = urllib.parse.urlencode(
{
**{
k: v
for (k, v) in request.query_params.items()
if k != "back"
},
"cursor": next_prev_links["next"]["cursor"],
**{k: v for (k, v) in request.query_params.items()},
**next_prev_links["prev"],
}
)
links.append(
{
"rel": "next",
"rel": "prev",
"href": f"{route_ref}?{qs}",
"type": stac_pydantic.shared.MimeTypes.json,
}
)
if next_prev_links.get("prev"):
if next_prev_links.get("next"):
qs = urllib.parse.urlencode(
{
**{k: v for (k, v) in request.query_params.items()},
"cursor": next_prev_links["prev"]["cursor"],
"back": "true",
**next_prev_links["next"],
}
)
links.append(
{
"rel": "prev",
"rel": "next",
"href": f"{route_ref}?{qs}",
"type": stac_pydantic.shared.MimeTypes.json,
}
Expand Down
Loading
Loading