diff --git a/.github/pages/make_switcher.py b/.github/pages/make_switcher.py index e2c8e6f6..6d90f490 100755 --- a/.github/pages/make_switcher.py +++ b/.github/pages/make_switcher.py @@ -3,28 +3,27 @@ from argparse import ArgumentParser from pathlib import Path from subprocess import CalledProcessError, check_output -from typing import List, Optional -def report_output(stdout: bytes, label: str) -> List[str]: +def report_output(stdout: bytes, label: str) -> list[str]: ret = stdout.decode().strip().split("\n") print(f"{label}: {ret}") return ret -def get_branch_contents(ref: str) -> List[str]: +def get_branch_contents(ref: str) -> list[str]: """Get the list of directories in a branch.""" stdout = check_output(["git", "ls-tree", "-d", "--name-only", ref]) return report_output(stdout, "Branch contents") -def get_sorted_tags_list() -> List[str]: +def get_sorted_tags_list() -> list[str]: """Get a list of sorted tags in descending order from the repository.""" stdout = check_output(["git", "tag", "-l", "--sort=-v:refname"]) return report_output(stdout, "Tags list") -def get_versions(ref: str, add: Optional[str]) -> List[str]: +def get_versions(ref: str, add: str | None) -> list[str]: """Generate the file containing the list of all GitHub Pages builds.""" # Get the directories (i.e. builds) from the GitHub Pages branch try: @@ -41,7 +40,7 @@ def get_versions(ref: str, add: Optional[str]) -> List[str]: tags = get_sorted_tags_list() # Make the sorted versions list from main branches and tags - versions: List[str] = [] + versions: list[str] = [] for version in ["master", "main"] + tags: if version in builds: versions.append(version) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7e17ec93..ce814ba1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] include: # Include one that runs in the dev environment - runs-on: "ubuntu-latest" diff --git a/docs/how-to/iterate-a-spec.rst b/docs/how-to/iterate-a-spec.rst index f966f77a..9d9333a2 100644 --- a/docs/how-to/iterate-a-spec.rst +++ b/docs/how-to/iterate-a-spec.rst @@ -18,9 +18,9 @@ frame. You can get these by using the `Spec.midpoints()` method to produce a >>> for d in spec.midpoints(): ... print(d) ... -{'x': 1.0} -{'x': 1.5} -{'x': 2.0} +{'x': np.float64(1.0)} +{'x': np.float64(1.5)} +{'x': np.float64(2.0)} This is simple, but not particularly performant, as the numpy arrays of points are unpacked point by point into point dictionaries diff --git a/pyproject.toml b/pyproject.toml index 799b38db..2d064382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,31 +7,19 @@ name = "scanspec" classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] description = "Specify step and flyscan paths in a serializable, efficient and Pythonic way" -dependencies = [ - "numpy>=1.19.3", - "click==8.1.3", - "pydantic<2.0", - "typing_extensions", -] +dependencies = ["numpy>=2", "click>=8.1", "pydantic<2.0", "httpx==0.26.0"] dynamic = ["version"] license.file = "LICENSE" readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.10" [project.optional-dependencies] # Plotting -plotting = [ - # make sure a python 3.9 compatible scipy and matplotlib are selected - "scipy>=1.5.4", - "matplotlib>=3.2.2", -] +plotting = ["scipy", "matplotlib"] # REST service support service = ["fastapi==0.99", "uvicorn"] # For development tests/docs @@ -131,8 +119,6 @@ extend-select = [ "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up ] -# We use pydantic, so don't upgrade to py3.10 syntax yet -pyupgrade.keep-runtime-typing = true ignore = [ "B008", # We use function calls in service arguments ] diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 21012737..ca162565 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -1,25 +1,12 @@ from __future__ import annotations +from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import field -from typing import ( - Any, - Callable, - Dict, - Generic, - Iterable, - Iterator, - List, - Optional, - Sequence, - Type, - TypeVar, - Union, -) +from typing import Any, Generic, Literal, TypeVar, Union import numpy as np from pydantic import BaseConfig, Extra, Field, ValidationError, create_model from pydantic.error_wrappers import ErrorWrapper -from typing_extensions import Literal __all__ = [ "if_instance_do", @@ -43,11 +30,11 @@ class StrictConfig(BaseConfig): def discriminated_union_of_subclasses( - super_cls: Optional[Type] = None, + super_cls: type | None = None, *, discriminator: str = "type", - config: Optional[Type[BaseConfig]] = None, -) -> Union[Type, Callable[[Type], Type]]: + config: type[BaseConfig] | None = None, +) -> type | Callable[[type], type]: """Add all subclasses of super_cls to a discriminated union. For all subclasses of super_cls, add a discriminator field to identify @@ -114,7 +101,7 @@ def calculate(self) -> int: subclasses. Defaults to None. Returns: - Union[Type, Callable[[Type], Type]]: A decorator that adds the necessary + Type | Callable[[Type], Type]: A decorator that adds the necessary functionality to a class. """ @@ -130,12 +117,12 @@ def wrap(cls): def _discriminated_union_of_subclasses( - super_cls: Type, + super_cls: type, discriminator: str, - config: Optional[Type[BaseConfig]] = None, -) -> Union[Type, Callable[[Type], Type]]: - super_cls._ref_classes = set() - super_cls._model = None + config: type[BaseConfig] | None = None, +) -> type | Callable[[type], type]: + super_cls._ref_classes = set() # type: ignore + super_cls._model = None # type: ignore def __init_subclass__(cls) -> None: # Keep track of inherting classes in super class @@ -157,7 +144,7 @@ def __validate__(cls, v: Any) -> Any: # needs to be done once, after all subclasses have been # declared if cls._model is None: - root = Union[tuple(cls._ref_classes)] # type: ignore + root = Union[tuple(cls._ref_classes)] # type: ignore # noqa cls._model = create_model( super_cls.__name__, __root__=(root, Field(..., discriminator=discriminator)), @@ -185,7 +172,7 @@ def __validate__(cls, v: Any) -> Any: return super_cls -def if_instance_do(x: Any, cls: Type, func: Callable): +def if_instance_do(x: Any, cls: type, func: Callable): """If x is of type cls then return func(x), otherwise return NotImplemented. Used as a helper when implementing operator overloading. @@ -201,7 +188,7 @@ def if_instance_do(x: Any, cls: Type, func: Callable): #: Map of axes to float ndarray of points #: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])} -AxesPoints = Dict[Axis, np.ndarray] +AxesPoints = dict[Axis, np.ndarray] class Frames(Generic[Axis]): @@ -234,9 +221,9 @@ class Frames(Generic[Axis]): def __init__( self, midpoints: AxesPoints[Axis], - lower: Optional[AxesPoints[Axis]] = None, - upper: Optional[AxesPoints[Axis]] = None, - gap: Optional[np.ndarray] = None, + lower: AxesPoints[Axis] | None = None, + upper: AxesPoints[Axis] | None = None, + gap: np.ndarray | None = None, ): #: The midpoints of scan frames for each axis self.midpoints = midpoints @@ -253,7 +240,9 @@ def __init__( # We have a gap if upper[i] != lower[i+1] for any axes axes_gap = [ np.roll(upper, 1) != lower - for upper, lower in zip(self.upper.values(), self.lower.values()) + for upper, lower in zip( + self.upper.values(), self.lower.values(), strict=False + ) ] self.gap = np.logical_or.reduce(axes_gap) # Check all axes and ordering are the same @@ -270,7 +259,7 @@ def __init__( lengths.add(len(self.gap)) assert len(lengths) <= 1, f"Mismatching lengths {list(lengths)}" - def axes(self) -> List[Axis]: + def axes(self) -> list[Axis]: """The axes which will move during the scan. These will be present in `midpoints`, `lower` and `upper`. @@ -300,7 +289,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]: return {k: v[dim_indices] for k, v in d.items()} return {} - def extract_gap(gaps: Iterable[np.ndarray]) -> Optional[np.ndarray]: + def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None: for gap in gaps: if not calculate_gap: return gap[dim_indices] @@ -371,7 +360,7 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray: def _merge_frames( *stack: Frames[Axis], dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]], # type: ignore - gap_merge=Callable[[Sequence[np.ndarray]], Optional[np.ndarray]], + gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None], ) -> Frames[Axis]: types = {type(fs) for fs in stack} assert len(types) == 1, f"Mismatching types for {stack}" @@ -397,9 +386,9 @@ class SnakedFrames(Frames[Axis]): def __init__( self, midpoints: AxesPoints[Axis], - lower: Optional[AxesPoints[Axis]] = None, - upper: Optional[AxesPoints[Axis]] = None, - gap: Optional[np.ndarray] = None, + lower: AxesPoints[Axis] | None = None, + upper: AxesPoints[Axis] | None = None, + gap: np.ndarray | None = None, ): super().__init__(midpoints, lower=lower, upper=upper, gap=gap) # Override first element of gap to be True, as subsequent runs @@ -431,7 +420,7 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]: length = len(self) backwards = (indices // length) % 2 snake_indices = np.where(backwards, (length - 1) - indices, indices) % length - cls: Type[Frames[Any]] + cls: type[Frames[Any]] if not calculate_gap: cls = Frames gap = self.gap[np.where(backwards, length - indices, indices) % length] @@ -464,7 +453,7 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool: return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes()) -def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[Axis]: +def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]: """Squash a stack of nested Frames into a single one. Args: @@ -530,7 +519,7 @@ class Path(Generic[Axis]): """ def __init__( - self, stack: List[Frames[Axis]], start: int = 0, num: Optional[int] = None + self, stack: list[Frames[Axis]], start: int = 0, num: int | None = None ): #: The Frames stack describing the scan, from slowest to fastest moving self.stack = stack @@ -544,7 +533,7 @@ def __init__( if num is not None and start + num < self.end_index: self.end_index = start + num - def consume(self, num: Optional[int] = None) -> Frames[Axis]: + def consume(self, num: int | None = None) -> Frames[Axis]: """Consume at most num frames from the Path and return as a Frames object. >>> fx = SnakedFrames({"x": np.array([1, 2])}) @@ -613,18 +602,18 @@ class Midpoints(Generic[Axis]): >>> fy = Frames({"y": np.array([3, 4])}) >>> mp = Midpoints([fy, fx]) >>> for p in mp: print(p) - {'y': 3, 'x': 1} - {'y': 3, 'x': 2} - {'y': 4, 'x': 2} - {'y': 4, 'x': 1} + {'y': np.int64(3), 'x': np.int64(1)} + {'y': np.int64(3), 'x': np.int64(2)} + {'y': np.int64(4), 'x': np.int64(2)} + {'y': np.int64(4), 'x': np.int64(1)} """ - def __init__(self, stack: List[Frames[Axis]]): + def __init__(self, stack: list[Frames[Axis]]): #: The stack of Frames describing the scan, from slowest to fastest moving self.stack = stack @property - def axes(self) -> List[Axis]: + def axes(self) -> list[Axis]: """The axes that will be present in each points dictionary.""" axes = [] for frames in self.stack: @@ -635,7 +624,7 @@ def __len__(self) -> int: """The number of dictionaries that will be produced if iterated over.""" return int(np.prod([len(frames) for frames in self.stack])) - def __iter__(self) -> Iterator[Dict[Axis, float]]: + def __iter__(self) -> Iterator[dict[Axis, float]]: """Yield {axis: midpoint} for each frame in the scan.""" path = Path(self.stack) while len(path): diff --git a/src/scanspec/plot.py b/src/scanspec/plot.py index 07db1949..43311663 100644 --- a/src/scanspec/plot.py +++ b/src/scanspec/plot.py @@ -1,5 +1,6 @@ +from collections.abc import Iterator from itertools import cycle -from typing import Any, Dict, Iterator, List, Optional +from typing import Any import numpy as np from matplotlib import colors, patches @@ -14,7 +15,7 @@ __all__ = ["plot_spec"] -def _plot_arrays(axes, arrays: List[np.ndarray], **kwargs): +def _plot_arrays(axes, arrays: list[np.ndarray], **kwargs): if len(arrays) > 2: axes.plot3D(arrays[2], arrays[1], arrays[0], **kwargs) elif len(arrays) == 2: @@ -38,7 +39,7 @@ def do_3d_projection(self, renderer=None): return np.min(zs) -def _plot_arrow(axes, arrays: List[np.ndarray]): +def _plot_arrow(axes, arrays: list[np.ndarray]): if len(arrays) == 1: arrays = [np.array([0, 0])] + arrays if len(arrays) == 2: @@ -58,8 +59,8 @@ def _plot_arrow(axes, arrays: List[np.ndarray]): axes.add_artist(a) -def _plot_spline(axes, ranges, arrays: List[np.ndarray], index_colours: Dict[int, str]): - scaled_arrays = [a / r for a, r in zip(arrays, ranges)] +def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int, str]): + scaled_arrays = [a / r for a, r in zip(arrays, ranges, strict=False)] # Define curves parametrically t = np.zeros(len(arrays[0])) t[1:] = np.sqrt(sum((arr[1:] - arr[:-1]) ** 2 for arr in scaled_arrays)) @@ -67,7 +68,7 @@ def _plot_spline(axes, ranges, arrays: List[np.ndarray], index_colours: Dict[int if t[-1] > 0: # Can't make a spline that starts and ends in the same place, so add a small # delta - for s, r in zip(scaled_arrays, ranges): + for s, r in zip(scaled_arrays, ranges, strict=False): if s[0] == s[-1]: s += np.linspace(0, r * 1e-7, len(s)) # There are no duplicated points, plot a spline @@ -76,16 +77,16 @@ def _plot_spline(axes, ranges, arrays: List[np.ndarray], index_colours: Dict[int tck, _ = interpolate.splprep(scaled_arrays, k=2, s=0) starts = sorted(index_colours) stops = starts[1:] + [len(arrays[0]) - 1] - for start, stop in zip(starts, stops): + for start, stop in zip(starts, stops, strict=False): tnew = np.linspace(t[start], t[stop], num=1001) spline = interpolate.splev(tnew, tck) # Scale the splines back to the original scaling - unscaled_splines = [a * r for a, r in zip(spline, ranges)] + unscaled_splines = [a * r for a, r in zip(spline, ranges, strict=False)] _plot_arrays(axes, unscaled_splines, color=index_colours[start]) yield unscaled_splines -def plot_spec(spec: Spec[Any], title: Optional[str] = None): +def plot_spec(spec: Spec[Any], title: str | None = None): """Plot a spec, drawing the path taken through the scan. Uses a different colour for each frame, grey for the turnarounds, and diff --git a/src/scanspec/regions.py b/src/scanspec/regions.py index 9f43cb57..201a8a4b 100644 --- a/src/scanspec/regions.py +++ b/src/scanspec/regions.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Generic, Iterator, List, Set +from collections.abc import Iterator +from typing import Generic import numpy as np from pydantic import BaseModel, Field @@ -43,7 +44,7 @@ class Region(Generic[Axis]): - ``^``: `SymmetricDifferenceOf` two Regions, midpoints present in one not both """ - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: """Produce the non-overlapping sets of axes this region spans.""" raise NotImplementedError(self) @@ -78,7 +79,7 @@ def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray: return np.ones(len(list(points.values())[0])) -def _merge_axis_sets(axis_sets: List[Set[Axis]]) -> Iterator[Set[Axis]]: +def _merge_axis_sets(axis_sets: list[set[Axis]]) -> Iterator[set[Axis]]: # Take overlapping axis sets and merge any that overlap into each # other for ks in axis_sets: # ks = key_sets - left over from a previous naming standard @@ -100,7 +101,7 @@ class CombinationOf(Region[Axis]): left: Region[Axis] = Field(description="The left-hand Region to combine") right: Region[Axis] = Field(description="The right-hand Region to combine") - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: axis_sets = list( _merge_axis_sets(self.left.axis_sets() + self.right.axis_sets()) ) @@ -187,7 +188,7 @@ class Range(Region[Axis]): min: float = Field(description="The minimum inclusive value in the region") max: float = Field(description="The minimum inclusive value in the region") - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: return [{self.axis}] def mask(self, points: AxesPoints[Axis]) -> np.ndarray: @@ -219,7 +220,7 @@ class Rectangle(Region[Axis]): description="Clockwise rotation angle of the rectangle", default=0.0 ) - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: return [{self.x_axis, self.y_axis}] def mask(self, points: AxesPoints[Axis]) -> np.ndarray: @@ -252,14 +253,14 @@ class Polygon(Region[Axis]): x_axis: Axis = Field(description="The name matching the x axis of the spec") y_axis: Axis = Field(description="The name matching the y axis of the spec") - x_verts: List[float] = Field( + x_verts: list[float] = Field( description="The Nx1 x coordinates of the polygons vertices", min_len=3 ) - y_verts: List[float] = Field( + y_verts: list[float] = Field( description="The Nx1 y coordinates of the polygons vertices", min_len=3 ) - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: return [{self.x_axis, self.y_axis}] def mask(self, points: AxesPoints[Axis]) -> np.ndarray: @@ -267,7 +268,7 @@ def mask(self, points: AxesPoints[Axis]) -> np.ndarray: y = points[self.y_axis] v1x, v1y = self.x_verts[-1], self.y_verts[-1] mask = np.full(len(x), False, dtype=np.int8) - for v2x, v2y in zip(self.x_verts, self.y_verts): + for v2x, v2y in zip(self.x_verts, self.y_verts, strict=False): # skip horizontal edges if v2y != v1y: vmask = np.full(len(x), False, dtype=np.int8) @@ -299,7 +300,7 @@ class Circle(Region[Axis]): y_middle: float = Field(description="The central y point of the circle") radius: float = Field(description="Radius of the circle", exc_min=0) - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: return [{self.x_axis, self.y_axis}] def mask(self, points: AxesPoints[Axis]) -> np.ndarray: @@ -334,7 +335,7 @@ class Ellipse(Region[Axis]): ) angle: float = Field(description="The angle of the ellipse (degrees)", default=0.0) - def axis_sets(self) -> List[Set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: return [{self.x_axis, self.y_axis}] def mask(self, points: AxesPoints[Axis]) -> np.ndarray: diff --git a/src/scanspec/service.py b/src/scanspec/service.py index 64865fed..939bfd8a 100644 --- a/src/scanspec/service.py +++ b/src/scanspec/service.py @@ -1,7 +1,7 @@ import base64 import json +from collections.abc import Mapping from enum import Enum -from typing import List, Mapping, Optional, Tuple, Union import numpy as np from fastapi import Body, FastAPI @@ -23,7 +23,7 @@ #: A set of points, that can be returned in various formats -Points = Union[str, List[float]] +Points = str | list[float] @dataclass @@ -47,7 +47,7 @@ class PointsRequest: """A request for generated scan points.""" spec: Spec = Field(description="The spec from which to generate points") - max_frames: Optional[int] = Field( + max_frames: int | None = Field( description="The maximum number of points to return, if None will return " "as many as calculated", default=100000, @@ -95,7 +95,7 @@ class BoundsResponse(GeneratedPointsResponse): class GapResponse: """Presence of gaps in a generated scan.""" - gap: List[bool] = Field( + gap: list[bool] = Field( description="Boolean array indicating if there is a gap between each frame" ) @@ -125,7 +125,7 @@ class SmallestStepResponse: @app.post("/valid", response_model=ValidResponse) def valid( spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]), -) -> Union[ValidResponse, JSONResponse]: +) -> ValidResponse | JSONResponse: """Validate wether a ScanSpec can produce a viable scan. Args: @@ -250,7 +250,7 @@ def smallest_step( # -def _to_chunk(request: PointsRequest) -> Tuple[Frames, int]: +def _to_chunk(request: PointsRequest) -> tuple[Frames, int]: spec = Spec.deserialize(request.spec) dims = spec.calculate() # Grab dimensions from spec path = Path(dims) # Convert to a path @@ -296,7 +296,7 @@ def _format_axes_points( raise KeyError(f"Unknown format: {format}") -def _reduce_frames(stack: List[Frames[str]], max_frames: int) -> Path: +def _reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path: """Removes frames from a spec so len(path) < max_frames. Args: @@ -327,7 +327,7 @@ def _sub_sample(frames: Frames[str], ratio: float) -> Frames: return frames.extract(indexes, calculate_gap=False) -def _calc_smallest_step(points: List[np.ndarray]) -> float: +def _calc_smallest_step(points: list[np.ndarray]) -> float: # Calc abs diffs of all axes, ignoring any zero values absolute_diffs = [_abs_diffs(axis_midpoints) for axis_midpoints in points] # Normalize and remove zeros diff --git a/src/scanspec/specs.py b/src/scanspec/specs.py index 03d0674b..cbafb126 100644 --- a/src/scanspec/specs.py +++ b/src/scanspec/specs.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from dataclasses import asdict -from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type +from typing import Any, Generic import numpy as np from pydantic import Field, parse_obj_as @@ -55,14 +56,14 @@ class Spec(Generic[Axis]): - ``~``: `Snake` the Spec, reversing every other iteration of it """ - def axes(self) -> List[Axis]: + def axes(self) -> list[Axis]: """Return the list of axes that are present in the scan. Ordered from slowest moving to fastest moving. """ raise NotImplementedError(self) - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: """Produce a stack of nested `Frames` that form the scan. Ordered from slowest moving to fastest moving. @@ -77,7 +78,7 @@ def midpoints(self) -> Midpoints[Axis]: """Return `Midpoints` that can be iterated point by point.""" return Midpoints(self.calculate(bounds=False)) - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: """Return the final, simplified shape of the scan.""" return tuple(len(dim) for dim in self.calculate()) @@ -127,10 +128,10 @@ class Product(Spec[Axis]): outer: Spec[Axis] = Field(description="Will be executed once") inner: Spec[Axis] = Field(description="Will be executed len(outer) times") - def axes(self) -> List: + def axes(self) -> list: return self.outer.axes() + self.inner.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: frames_outer = self.outer.calculate(bounds=False, nested=nested) frames_inner = self.inner.calculate(bounds, nested=True) return frames_outer + frames_inner @@ -166,10 +167,10 @@ class Repeat(Spec[Axis]): default=True, ) - def axes(self) -> List: + def axes(self) -> list: return [] - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: return [Frames({}, gap=np.full(self.num, self.gap))] @@ -203,10 +204,10 @@ class Zip(Spec[Axis]): description="The right-hand Spec to Zip, will appear later in axes" ) - def axes(self) -> List: + def axes(self) -> list: return self.left.axes() + self.right.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: frames_left = self.left.calculate(bounds, nested) frames_right = self.right.calculate(bounds, nested) assert len(frames_left) >= len( @@ -225,14 +226,14 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: # Left pad frames_right with Nones so they are the same size npad = len(frames_left) - len(frames_right) - padded_right: List[Optional[Frames[Axis]]] = [None] * npad + padded_right: list[Frames[Axis] | None] = [None] * npad # Mypy doesn't like this because lists are invariant: # https://github.com/python/mypy/issues/4244 padded_right += frames_right # type: ignore # Work through, zipping them together one by one frames = [] - for left, right in zip(frames_left, padded_right): + for left, right in zip(frames_left, padded_right, strict=False): if right is None: combined = left else: @@ -271,10 +272,10 @@ class Mask(Spec[Axis]): default=True, ) - def axes(self) -> List: + def axes(self) -> list: return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: frames = self.spec.calculate(bounds, nested) for axis_set in self.region.axis_sets(): # Find the start and end index of any dimensions containing these axes @@ -329,10 +330,10 @@ class Snake(Spec[Axis]): description="The Spec to run in reverse every other iteration" ) - def axes(self) -> List: + def axes(self) -> list: return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: return [ SnakedFrames.from_frames(segment) for segment in self.spec.calculate(bounds, nested) @@ -368,14 +369,14 @@ class Concat(Spec[Axis]): default=True, ) - def axes(self) -> List: + def axes(self) -> list: left_axes, right_axes = self.left.axes(), self.right.axes() # Assuming the axes are the same, the order does not matter, we inherit the # order from the left-hand side. See also scanspec.core.concat. assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}" return left_axes - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: dim_left = squash_frames( self.left.calculate(bounds, nested), nested and self.check_path_changes ) @@ -406,21 +407,21 @@ class Squash(Spec[Axis]): default=True, ) - def axes(self) -> List: + def axes(self) -> list: return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: dims = self.spec.calculate(bounds, nested) dim = squash_frames(dims, nested and self.check_path_changes) return [dim] def _dimensions_from_indexes( - func: Callable[[np.ndarray], Dict[Axis, np.ndarray]], - axes: List, + func: Callable[[np.ndarray], dict[Axis, np.ndarray]], + axes: list, num: int, bounds: bool, -) -> List[Frames[Axis]]: +) -> list[Frames[Axis]]: # Calc num midpoints (fences) from 0.5 .. num - 0.5 midpoints_calc = func(np.linspace(0.5, num - 0.5, num)) midpoints = {a: midpoints_calc[a] for a in axes} @@ -458,10 +459,10 @@ class Line(Spec[Axis]): stop: float = Field(description="Midpoint of the last point of the line") num: int = Field(min=1, description="Number of frames to produce") - def axes(self) -> List: + def axes(self) -> list: return [self.axis] - def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: + def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: if self.num == 1: # Only one point, stop-start gives length of one point step = self.stop - self.start @@ -473,7 +474,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: first = self.start - step / 2 return {self.axis: indexes * step + first} - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: return _dimensions_from_indexes( self._line_from_indexes, self.axes(), self.num, bounds ) @@ -524,7 +525,7 @@ class Static(Spec[Axis]): @classmethod def duration( - cls: Type[Static], + cls: type[Static], duration: float = Field(description="The duration of each static point"), num: int = Field(min=1, description="Number of frames to produce", default=1), ) -> Static[str]: @@ -538,13 +539,13 @@ def duration( """ return cls(DURATION, duration, num) - def axes(self) -> List: + def axes(self) -> list: return [self.axis] - def _repeats_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: + def _repeats_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: return {self.axis: np.full(len(indexes), self.value)} - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: return _dimensions_from_indexes( self._repeats_from_indexes, self.axes(), self.num, bounds ) @@ -577,11 +578,11 @@ class Spiral(Spec[Axis]): description="How much to rotate the angle of the spiral", default=0.0 ) - def axes(self) -> List[Axis]: + def axes(self) -> list[Axis]: # TODO: reversed from __init__ args, a good idea? return [self.y_axis, self.x_axis] - def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: + def _spiral_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: # simplest spiral equation: r = phi # we want point spacing across area to be the same as between rings # so: sqrt(area / num) = ring_spacing @@ -598,7 +599,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: self.x_axis: self.x_start + x_scale * phi * np.sin(phi + self.rotate), } - def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: return _dimensions_from_indexes( self._spiral_from_indexes, self.axes(), self.num, bounds ) @@ -669,7 +670,7 @@ def step(spec: Spec[Axis], duration: float, num: int = 1) -> Spec[Axis]: return spec * Static.duration(duration, num) -def get_constant_duration(frames: List[Frames]) -> Optional[float]: +def get_constant_duration(frames: list[Frames]) -> float | None: """ Returns the duration of a number of ScanSpec frames, if known and consistent. diff --git a/tests/test_boilerplate_removed.py b/tests/test_boilerplate_removed.py index 22949dcc..6736e80b 100644 --- a/tests/test_boilerplate_removed.py +++ b/tests/test_boilerplate_removed.py @@ -2,6 +2,7 @@ This file checks that all the example boilerplate text has been removed. It can be deleted when all the contained tests pass """ + from importlib.metadata import metadata from pathlib import Path diff --git a/tests/test_cli.py b/tests/test_cli.py index 9f3641fa..4268f017 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,7 @@ import pathlib import subprocess import sys -from typing import List, cast +from typing import cast from unittest.mock import patch import matplotlib.pyplot as plt @@ -20,7 +20,7 @@ def assert_min_max_2d(line, xmin, xmax, ymin, ymax, length=None): assert len(line.get_data()[0]) == length mins = np.min(line.get_data(), axis=1) maxs = np.max(line.get_data(), axis=1) - assert list(zip(mins, maxs)) == [ + assert list(zip(mins, maxs, strict=False)) == [ pytest.approx([xmin, xmax]), pytest.approx([ymin, ymax]), ] @@ -31,7 +31,7 @@ def assert_min_max_3d(line, xmin, xmax, ymin, ymax, zmin, zmax, length=None): assert len(line.get_data_3d()[0]) == length mins = np.min(line.get_data_3d(), axis=1) maxs = np.max(line.get_data_3d(), axis=1) - assert list(zip(mins, maxs)) == [ + assert list(zip(mins, maxs, strict=False)) == [ pytest.approx([xmin, xmax]), pytest.approx([ymin, ymax]), pytest.approx([zmin, zmax]), @@ -61,7 +61,7 @@ def test_plot_1D_line() -> None: # End assert_min_max_2d(lines[3], 2.5, 2.5, 0, 0) # Arrows - texts = cast(List[Annotation], axes.texts) + texts = cast(list[Annotation], axes.texts) assert len(texts) == 1 assert tuple(texts[0].xy) == (0.5, 0) @@ -86,7 +86,7 @@ def test_plot_1D_line_snake_repeat() -> None: # End assert_min_max_2d(lines[4], 1, 1, 0, 0) # Arrows - texts = cast(List[Annotation], axes.texts) + texts = cast(list[Annotation], axes.texts) assert len(texts) == 2 assert tuple(texts[0].xy) == (1, 0) assert tuple(texts[1].xy) == pytest.approx([2, 0]) @@ -110,7 +110,7 @@ def test_plot_1D_step() -> None: # End assert_min_max_2d(lines[3], 2, 2, 0, 0) # Arrows - texts = cast(List[Annotation], axes.texts) + texts = cast(list[Annotation], axes.texts) assert len(texts) == 1 assert tuple(texts[0].xy) == (2, 0) @@ -137,7 +137,7 @@ def test_plot_2D_line() -> None: # End assert_min_max_2d(lines[6], 0.5, 0.5, 3, 3) # Arrows - texts = cast(List[Annotation], axes.texts) + texts = cast(list[Annotation], axes.texts) assert len(texts) == 2 assert tuple(texts[0].xy) == (0.5, 2) assert tuple(texts[1].xy) == pytest.approx([2.5, 3]) @@ -164,7 +164,7 @@ def test_plot_2D_line_rect_region() -> None: # End assert_min_max_2d(lines[5], 1.5, 1.5, 2, 2) # Arrows - texts = cast(List[Annotation], axes.texts) + texts = cast(list[Annotation], axes.texts) assert len(texts) == 2 assert tuple(texts[0].xy) == (-0.5, 1.5) assert tuple(texts[1].xy) == (-0.5, 2) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 560c694d..adb5729d 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import pytest from pydantic import ValidationError diff --git a/tests/test_specs.py b/tests/test_specs.py index 49b72497..dad0ffb6 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -1,5 +1,5 @@ import re -from typing import Any, Tuple +from typing import Any import pytest @@ -558,7 +558,7 @@ def test_multiple_statics_with_grid(): ), ], ) -def test_shape(spec: Spec, expected_shape: Tuple[int, ...]): +def test_shape(spec: Spec, expected_shape: tuple[int, ...]): assert expected_shape == spec.shape()