From ccb7509c1f8a9ddd2a2d488725f7a20145e3f167 Mon Sep 17 00:00:00 2001 From: Frank Anema <33519926+Conengmo@users.noreply.github.com> Date: Sun, 29 Dec 2024 15:36:13 +0100 Subject: [PATCH] Fix type hints after adding Branca type checking (#2060) * remove render() return types * fix TypeBounds * missing return statement * split TypeBounds in input and return types * deal with bounds from args to return * fix VegaLite typing * geojsondetail assert parent is geojson * bin_edges in choropleth * geojson/topojson in choropleth * colormap type in ColorLine * ruff check * black * fix circular import --- folium/elements.py | 2 +- folium/features.py | 78 ++++++++++++------- folium/folium.py | 2 +- folium/map.py | 13 ++-- .../plugins/overlapping_marker_spiderfier.py | 2 +- folium/raster_layers.py | 16 ++-- folium/utilities.py | 13 +++- tests/test_utilities.py | 15 ++++ 8 files changed, 96 insertions(+), 45 deletions(-) diff --git a/folium/elements.py b/folium/elements.py index aa1dbd884..3dbbc82ff 100644 --- a/folium/elements.py +++ b/folium/elements.py @@ -12,7 +12,7 @@ class JSCSSMixin(Element): default_js: List[Tuple[str, str]] = [] default_css: List[Tuple[str, str]] = [] - def render(self, **kwargs) -> None: + def render(self, **kwargs): figure = self.get_root() assert isinstance( figure, Figure diff --git a/folium/features.py b/folium/features.py index cba87e8ef..83a7b7fa4 100644 --- a/folium/features.py +++ b/folium/features.py @@ -12,7 +12,15 @@ import numpy as np import requests from branca.colormap import ColorMap, LinearColormap, StepColormap -from branca.element import Element, Figure, Html, IFrame, JavascriptLink, MacroElement +from branca.element import ( + Div, + Element, + Figure, + Html, + IFrame, + JavascriptLink, + MacroElement, +) from branca.utilities import color_brewer from folium.elements import JSCSSMixin @@ -20,6 +28,8 @@ from folium.map import FeatureGroup, Icon, Layer, Marker, Popup, Tooltip from folium.template import Template from folium.utilities import ( + TypeBoundsReturn, + TypeContainer, TypeJsonValue, TypeLine, TypePathOptions, @@ -165,7 +175,7 @@ def __init__( self.top = _parse_size(top) self.position = position - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" super().render(**kwargs) @@ -284,9 +294,15 @@ def __init__( self.top = _parse_size(top) self.position = position - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" - self._parent.html.add_child( + parent = self._parent + if not isinstance(parent, (Figure, Div, Popup)): + raise TypeError( + "VegaLite elements can only be added to a Figure, Div, or Popup" + ) + + parent.html.add_child( Element( Template( """ @@ -331,7 +347,7 @@ def render(self, **kwargs) -> None: embed_vegalite = embed_mapping.get( self.vegalite_major_version, self._embed_vegalite_v2 ) - embed_vegalite(figure) + embed_vegalite(figure=figure, parent=parent) @property def vegalite_major_version(self) -> Optional[int]: @@ -342,8 +358,8 @@ def vegalite_major_version(self) -> Optional[int]: return int(schema.split("/")[-1].split(".")[0].lstrip("v")) - def _embed_vegalite_v5(self, figure: Figure) -> None: - self._vega_embed() + def _embed_vegalite_v5(self, figure: Figure, parent: TypeContainer) -> None: + self._vega_embed(parent=parent) figure.header.add_child( JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega" @@ -356,8 +372,8 @@ def _embed_vegalite_v5(self, figure: Figure) -> None: name="vega-embed", ) - def _embed_vegalite_v4(self, figure: Figure) -> None: - self._vega_embed() + def _embed_vegalite_v4(self, figure: Figure, parent: TypeContainer) -> None: + self._vega_embed(parent=parent) figure.header.add_child( JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega" @@ -370,8 +386,8 @@ def _embed_vegalite_v4(self, figure: Figure) -> None: name="vega-embed", ) - def _embed_vegalite_v3(self, figure: Figure) -> None: - self._vega_embed() + def _embed_vegalite_v3(self, figure: Figure, parent: TypeContainer) -> None: + self._vega_embed(parent=parent) figure.header.add_child( JavascriptLink("https://cdn.jsdelivr.net/npm/vega@4"), name="vega" @@ -384,8 +400,8 @@ def _embed_vegalite_v3(self, figure: Figure) -> None: name="vega-embed", ) - def _embed_vegalite_v2(self, figure: Figure) -> None: - self._vega_embed() + def _embed_vegalite_v2(self, figure: Figure, parent: TypeContainer) -> None: + self._vega_embed(parent=parent) figure.header.add_child( JavascriptLink("https://cdn.jsdelivr.net/npm/vega@3"), name="vega" @@ -398,8 +414,8 @@ def _embed_vegalite_v2(self, figure: Figure) -> None: name="vega-embed", ) - def _vega_embed(self) -> None: - self._parent.script.add_child( + def _vega_embed(self, parent: TypeContainer) -> None: + parent.script.add_child( Element( Template( """ @@ -412,8 +428,8 @@ def _vega_embed(self) -> None: name=self.get_name(), ) - def _embed_vegalite_v1(self, figure: Figure) -> None: - self._parent.script.add_child( + def _embed_vegalite_v1(self, figure: Figure, parent: TypeContainer) -> None: + parent.script.add_child( Element( Template( """ @@ -436,19 +452,19 @@ def _embed_vegalite_v1(self, figure: Figure) -> None: figure.header.add_child( JavascriptLink("https://cdnjs.cloudflare.com/ajax/libs/vega/2.6.5/vega.js"), name="vega", - ) # noqa + ) figure.header.add_child( JavascriptLink( "https://cdnjs.cloudflare.com/ajax/libs/vega-lite/1.3.1/vega-lite.js" ), name="vega-lite", - ) # noqa + ) figure.header.add_child( JavascriptLink( "https://cdnjs.cloudflare.com/ajax/libs/vega-embed/2.2.0/vega-embed.js" ), name="vega-embed", - ) # noqa + ) class GeoJson(Layer): @@ -820,7 +836,7 @@ def _get_self_bounds(self) -> List[List[Optional[float]]]: """ return get_bounds(self.data, lonlat=True) - def render(self, **kwargs) -> None: + def render(self, **kwargs): self.parent_map = get_obj_in_upper_tree(self, Map) # Need at least one feature, otherwise style mapping fails if (self.style or self.highlight) and self.data["features"]: @@ -1041,12 +1057,12 @@ def recursive_get(data, keys): self.style_function(feature) ) # noqa - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" self.style_data() super().render(**kwargs) - def get_bounds(self) -> List[List[float]]: + def get_bounds(self) -> TypeBoundsReturn: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]] @@ -1146,6 +1162,7 @@ def __init__( def warn_for_geometry_collections(self) -> None: """Checks for GeoJson GeometryCollection features to warn user about incompatibility.""" + assert isinstance(self._parent, GeoJson) geom_collections = [ feature.get("properties") if feature.get("properties") is not None else key for key, feature in enumerate(self._parent.data["features"]) @@ -1160,7 +1177,7 @@ def warn_for_geometry_collections(self) -> None: UserWarning, ) - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" figure = self.get_root() if isinstance(self._parent, GeoJson): @@ -1565,7 +1582,7 @@ def __init__( color_range = color_brewer(fill_color, n=nb_bins) self.color_scale = StepColormap( color_range, - index=bin_edges, + index=list(bin_edges), vmin=bins_min, vmax=bins_max, caption=legend_name, @@ -1625,7 +1642,7 @@ def highlight_function(x): return {"weight": line_weight + 2, "fillOpacity": fill_opacity + 0.2} if topojson: - self.geojson = TopoJson( + self.geojson: Union[TopoJson, GeoJson] = TopoJson( geo_data, topojson, style_function=style_function, @@ -1657,7 +1674,7 @@ def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None else: return value - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Render the GeoJson/TopoJson and color scale objects.""" if self.color_scale: # ColorMap needs Map as its parent @@ -1963,8 +1980,13 @@ def __init__( vmin=min(colors), vmax=max(colors), ).to_step(nb_steps) - else: + elif isinstance(colormap, StepColormap): cm = colormap + else: + raise TypeError( + f"Unexpected type for argument `colormap`: {type(colormap)}" + ) + out: Dict[str, List[List[List[float]]]] = {} for (lat1, lng1), (lat2, lng2), color in zip(coords[:-1], coords[1:], colors): out.setdefault(cm(color), []).append([[lat1, lng1], [lat2, lng2]]) diff --git a/folium/folium.py b/folium/folium.py index e48049371..12510ecb4 100644 --- a/folium/folium.py +++ b/folium/folium.py @@ -377,7 +377,7 @@ def _repr_png_(self) -> Optional[bytes]: return None return self._to_png() - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" figure = self.get_root() assert isinstance( diff --git a/folium/map.py b/folium/map.py index 758e64153..36c465594 100644 --- a/folium/map.py +++ b/folium/map.py @@ -5,7 +5,7 @@ import warnings from collections import OrderedDict -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Optional, Sequence, Union, cast from branca.element import Element, Figure, Html, MacroElement @@ -14,6 +14,7 @@ from folium.utilities import ( JsCode, TypeBounds, + TypeBoundsReturn, TypeJsonValue, escape_backticks, parse_options, @@ -221,7 +222,7 @@ def reset(self) -> None: self.base_layers = OrderedDict() self.overlays = OrderedDict() - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" self.reset() for item in self._parent._children.values(): @@ -396,15 +397,15 @@ def __init__( tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip)) ) - def _get_self_bounds(self) -> List[List[float]]: + def _get_self_bounds(self) -> TypeBoundsReturn: """Computes the bounds of the object itself. Because a marker has only single coordinates, we repeat them. """ assert self.location is not None - return [self.location, self.location] + return cast(TypeBoundsReturn, [self.location, self.location]) - def render(self) -> None: + def render(self): if self.location is None: raise ValueError( f"{self._name} location must be assigned when added directly to map." @@ -492,7 +493,7 @@ def __init__( **kwargs, ) - def render(self, **kwargs) -> None: + def render(self, **kwargs): """Renders the HTML representation of the element.""" for name, child in self._children.items(): child.render(**kwargs) diff --git a/folium/plugins/overlapping_marker_spiderfier.py b/folium/plugins/overlapping_marker_spiderfier.py index 70fcea412..e6335aaf4 100644 --- a/folium/plugins/overlapping_marker_spiderfier.py +++ b/folium/plugins/overlapping_marker_spiderfier.py @@ -92,7 +92,7 @@ def add_to( ) -> Element: self._parent = parent self.markers = self._get_all_markers(parent) - super().add_to(parent, name=name, index=index) + return super().add_to(parent, name=name, index=index) def _get_all_markers(self, element: Element) -> list: markers = [] diff --git a/folium/raster_layers.py b/folium/raster_layers.py index bbb2f0b9f..0ff6ef994 100644 --- a/folium/raster_layers.py +++ b/folium/raster_layers.py @@ -12,9 +12,11 @@ from folium.template import Template from folium.utilities import ( TypeBounds, + TypeBoundsReturn, TypeJsonValue, image_to_url, mercator_transform, + normalize_bounds_type, parse_options, remove_empty, ) @@ -246,7 +248,7 @@ class ImageOverlay(Layer): * If string, it will be written directly in the output file. * If file, it's content will be converted as embedded in the output file. * If array-like, it will be converted to PNG base64 string and embedded in the output. - bounds: list + bounds: list/tuple of list/tuple of float Image bounds on the map in the form [[lat_min, lon_min], [lat_max, lon_max]] opacity: float, default Leaflet's default (1.0) @@ -319,7 +321,7 @@ def __init__( self.url = image_to_url(image, origin=origin, colormap=colormap) - def render(self, **kwargs) -> None: + def render(self, **kwargs): super().render() figure = self.get_root() @@ -344,13 +346,13 @@ def render(self, **kwargs) -> None: Element(pixelated), name="leaflet-image-layer" ) # noqa - def _get_self_bounds(self) -> TypeBounds: + def _get_self_bounds(self) -> TypeBoundsReturn: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]]. """ - return self.bounds + return normalize_bounds_type(self.bounds) class VideoOverlay(Layer): @@ -361,7 +363,7 @@ class VideoOverlay(Layer): ---------- video_url: str URL of the video - bounds: list + bounds: list/tuple of list/tuple of float Video bounds on the map in the form [[lat_min, lon_min], [lat_max, lon_max]] autoplay: bool, default True @@ -411,10 +413,10 @@ def __init__( self.bounds = bounds self.options = remove_empty(autoplay=autoplay, loop=loop, **kwargs) - def _get_self_bounds(self) -> TypeBounds: + def _get_self_bounds(self) -> TypeBoundsReturn: """ Computes the bounds of the object itself (not including it's children) in the form [[lat_min, lon_min], [lat_max, lon_max]] """ - return self.bounds + return normalize_bounds_type(self.bounds) diff --git a/folium/utilities.py b/folium/utilities.py index a730417c3..9210bf59e 100644 --- a/folium/utilities.py +++ b/folium/utilities.py @@ -9,6 +9,7 @@ import uuid from contextlib import contextmanager from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -24,7 +25,7 @@ from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative import numpy as np -from branca.element import Element, Figure +from branca.element import Div, Element, Figure # import here for backwards compatibility from branca.utilities import ( # noqa F401 @@ -40,6 +41,9 @@ except ImportError: pd = None +if TYPE_CHECKING: + from .features import Popup + TypeLine = Iterable[Sequence[float]] TypeMultiLine = Union[TypeLine, Iterable[TypeLine]] @@ -50,6 +54,9 @@ TypePathOptions = Union[bool, str, float, None] TypeBounds = Sequence[Sequence[float]] +TypeBoundsReturn = List[List[Optional[float]]] + +TypeContainer = Union[Figure, Div, "Popup"] _VALID_URLS = set(uses_relative + uses_netloc + uses_params) @@ -325,6 +332,10 @@ def get_bounds( return bounds +def normalize_bounds_type(bounds: TypeBounds) -> TypeBoundsReturn: + return [[float(x) if x is not None else None for x in y] for y in bounds] + + def camelize(key: str) -> str: """Convert a python_style_variable_name to lowerCamelCase. diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 2218bba73..a1f10be36 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -12,6 +12,7 @@ get_obj_in_upper_tree, if_pandas_df_convert_to_numpy, javascript_identifier_path_to_array_notation, + normalize_bounds_type, parse_font_size, parse_options, validate_location, @@ -133,6 +134,20 @@ def test_if_pandas_df_convert_to_numpy(): assert if_pandas_df_convert_to_numpy(expected) is expected +@pytest.mark.parametrize( + "bounds, expected", + [ + ([[1, 2], [3, 4]], [[1.0, 2.0], [3.0, 4.0]]), + ([[None, 2], [3, None]], [[None, 2.0], [3.0, None]]), + ([[1.1, 2.2], [3.3, 4.4]], [[1.1, 2.2], [3.3, 4.4]]), + ([[None, None], [None, None]], [[None, None], [None, None]]), + ([[0, -1], [-2, 3]], [[0.0, -1.0], [-2.0, 3.0]]), + ], +) +def test_normalize_bounds_type(bounds, expected): + assert normalize_bounds_type(bounds) == expected + + def test_camelize(): assert camelize("variable_name") == "variableName" assert camelize("variableName") == "variableName"