Skip to content

Commit

Permalink
move to pydantic v2 (#127)
Browse files Browse the repository at this point in the history
Refactored discriminated union of subclasses
  • Loading branch information
ZohebShaikh authored Aug 7, 2024
1 parent c412930 commit f49320c
Show file tree
Hide file tree
Showing 9 changed files with 4,029 additions and 107 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
description = "Specify step and flyscan paths in a serializable, efficient and Pythonic way"
dependencies = ["numpy>=2", "click>=8.1", "pydantic<2.0", "httpx==0.26.0"]
dependencies = [
"numpy>=2",
"click>=8.1",
"pydantic>=2.0",
]
dynamic = ["version"]
license.file = "LICENSE"
readme = "README.md"
Expand All @@ -21,7 +25,7 @@ requires-python = ">=3.10"
# Plotting
plotting = ["scipy", "matplotlib"]
# REST service support
service = ["fastapi==0.99", "uvicorn"]
service = ["fastapi>=0.100.0", "uvicorn"]
# For development tests/docs
dev = [
# This syntax is supported since pip 21.2
Expand Down
3,806 changes: 3,805 additions & 1 deletion schema.json

Large diffs are not rendered by default.

216 changes: 139 additions & 77 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import field
from typing import Any, Generic, Literal, TypeVar, Union
from functools import partial
from inspect import isclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
Union,
get_origin,
get_type_hints,
)

import numpy as np
from pydantic import BaseConfig, Extra, Field, ValidationError, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic import (
ConfigDict,
Field,
GetCoreSchemaHandler,
TypeAdapter,
)
from pydantic.dataclasses import rebuild_dataclass
from pydantic.fields import FieldInfo

__all__ = [
"if_instance_do",
Expand All @@ -23,24 +39,18 @@
]


class StrictConfig(BaseConfig):
"""Pydantic configuration for scanspecs and regions."""

extra: Extra = Extra.forbid
StrictConfig: ConfigDict = {"extra": "forbid"}


def discriminated_union_of_subclasses(
super_cls: type | None = None,
*,
cls,
discriminator: str = "type",
config: type[BaseConfig] | None = None,
) -> type | Callable[[type], type]:
):
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
the type. Raw JSON should look like {"type": <type name>, params for
<type name>...}.
Add validation methods to super_cls so it can be parsed by pydantic.parse_obj_as.
Example::
Expand Down Expand Up @@ -104,72 +114,124 @@ def calculate(self) -> int:
Type | Callable[[Type], Type]: A decorator that adds the necessary
functionality to a class.
"""
tagged_union = _TaggedUnion(cls, discriminator)
_tagged_unions[cls] = tagged_union
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
cls.__get_pydantic_core_schema__ = classmethod(
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
)
return cls

def wrap(cls):
return _discriminated_union_of_subclasses(cls, discriminator, config)

# Work out if the call was @discriminated_union_of_subclasses or
# @discriminated_union_of_subclasses(...)
if super_cls is None:
return wrap
else:
return wrap(super_cls)


def _discriminated_union_of_subclasses(
super_cls: type,
discriminator: str,
config: type[BaseConfig] | None = None,
) -> type | Callable[[type], type]:
super_cls._ref_classes = set() # type: ignore
super_cls._model = None # type: ignore

def __init_subclass__(cls) -> None:
# Keep track of inherting classes in super class
cls._ref_classes.add(cls)

# Add a discriminator field to the class so it can
# be identified when deserailizing.
cls.__annotations__ = {
**cls.__annotations__,
discriminator: Literal[cls.__name__],
}
setattr(cls, discriminator, field(default=cls.__name__, repr=False))

def __get_validators__(cls) -> Any:
yield cls.__validate__

def __validate__(cls, v: Any) -> Any:
# Lazily initialize model on first use because this
# needs to be done once, after all subclasses have been
# declared
if cls._model is None:
root = Union[tuple(cls._ref_classes)] # type: ignore # noqa
cls._model = create_model(
super_cls.__name__,
__root__=(root, Field(..., discriminator=discriminator)),
__config__=config,
)

try:
return cls._model(__root__=v).__root__
except ValidationError as e:
for (
error
) in e.raw_errors: # need in to remove redundant __root__ from error path
if (
isinstance(error, ErrorWrapper)
and error.loc_tuple()[0] == "__root__"
):
error._loc = error.loc_tuple()[1:]

raise e

# Inject magic methods into super_cls
for method in __init_subclass__, __get_validators__, __validate__:
setattr(super_cls, method.__name__, classmethod(method)) # type: ignore

return super_cls
T = TypeVar("T", type, Callable)


def deserialize_as(cls, obj):
return _tagged_unions[cls].type_adapter.validate_python(obj)


def uses_tagged_union(cls_or_func: T) -> T:
"""
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for k, v in get_type_hints(cls_or_func).items():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls_or_func, k)
return cls_or_func


class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
# The members of the tagged union, i.e. subclasses of the baseclasses
self._members: list[type] = []
# Classes and their field names that refer to this tagged union
self._referrers: dict[type | Callable, set[str]] = {}
self.type_adapter: TypeAdapter = TypeAdapter(None)
self._discriminator = discriminator

def _make_union(self):
if len(self._members) > 0:
return Union[tuple(self._members)] # type: ignore # noqa

def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
# Set the field to use the `type` discriminator on deserialize
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
if isclass(cls):
assert isinstance(
field, FieldInfo
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
field.discriminator = self._discriminator

def add_member(self, cls: type):
if cls in self._members:
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
# called muliple times for the same member, do no process if it wouldn't
# change the member list
return

self._members.append(cls)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set all the referrers
# to use this union
for referrer, fields in self._referrers.items():
if isclass(referrer):
for field in dataclasses.fields(referrer):
if field.name in fields:
field.type = union
self._set_discriminator(referrer, field.name, field.default)
rebuild_dataclass(referrer, force=True)
# Make a type adapter for use in deserialization
self.type_adapter = TypeAdapter(union)

def add_referrer(self, cls: type | Callable, attr_name: str):
self._referrers.setdefault(cls, set()).add(attr_name)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set the referrer
# (which is currently being constructed) to use it
# note that we use annotations as the class has not been turned into
# a dataclass yet
cls.__annotations__[attr_name] = union
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))


_tagged_unions: dict[type, _TaggedUnion] = {}


def __init_subclass__(discriminator: str, cls: type):
# Add a discriminator field to the class so it can
# be identified when deserailizing, and make sure it is last in the list
cls.__annotations__ = {
**cls.__annotations__,
discriminator: Literal[cls.__name__], # type: ignore
}
cls.type = Field(cls.__name__, repr=False) # type: ignore
# Replace any bare annotation with a discriminated union of subclasses
# and register this class as one that refers to that union so it can be updated
for k, v in get_type_hints(cls).items():
# This works for Expression[T] or Expression
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls, k)


def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion
):
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
tagged_union.add_member(cls)
return handler(source_type)


def if_instance_do(x: Any, cls: type, func: Callable):
Expand Down
3 changes: 2 additions & 1 deletion src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mpl_toolkits.mplot3d import Axes3D, proj3d
from scipy import interpolate

from .core import Path
from .core import Path, uses_tagged_union
from .regions import Circle, Ellipse, Polygon, Rectangle, Region, find_regions
from .specs import DURATION, Spec

Expand Down Expand Up @@ -86,6 +86,7 @@ def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int
yield unscaled_splines


@uses_tagged_union
def plot_spec(spec: Spec[Any], title: str | None = None):
"""Plot a spec, drawing the path taken through the scan.
Expand Down
31 changes: 22 additions & 9 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Iterator
from typing import Generic
from collections.abc import Iterator, Mapping
from dataclasses import asdict, is_dataclass
from typing import Any, Generic

import numpy as np
from pydantic import BaseModel, Field
Expand All @@ -11,6 +12,7 @@
AxesPoints,
Axis,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
if_instance_do,
)
Expand Down Expand Up @@ -64,6 +66,15 @@ def __sub__(self, other) -> DifferenceOf[Axis]:
def __xor__(self, other) -> SymmetricDifferenceOf[Axis]:
return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o))

def serialize(self) -> Mapping[str, Any]:
"""Serialize the Region to a dictionary."""
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
return deserialize_as(Region, obj)


def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:
"""Return a mask of the points inside the region.
Expand Down Expand Up @@ -254,10 +265,10 @@ class Polygon(Region[Axis]):
x_axis: Axis = Field(description="The name matching the x axis of the spec")
y_axis: Axis = Field(description="The name matching the y axis of the spec")
x_verts: list[float] = Field(
description="The Nx1 x coordinates of the polygons vertices", min_len=3
description="The Nx1 x coordinates of the polygons vertices", min_length=3
)
y_verts: list[float] = Field(
description="The Nx1 y coordinates of the polygons vertices", min_len=3
description="The Nx1 y coordinates of the polygons vertices", min_length=3
)

def axis_sets(self) -> list[set[Axis]]:
Expand Down Expand Up @@ -298,7 +309,7 @@ class Circle(Region[Axis]):
y_axis: Axis = Field(description="The name matching the y axis of the spec")
x_middle: float = Field(description="The central x point of the circle")
y_middle: float = Field(description="The central y point of the circle")
radius: float = Field(description="Radius of the circle", exc_min=0)
radius: float = Field(description="Radius of the circle", gt=0)

def axis_sets(self) -> list[set[Axis]]:
return [{self.x_axis, self.y_axis}]
Expand Down Expand Up @@ -328,10 +339,10 @@ class Ellipse(Region[Axis]):
x_middle: float = Field(description="The central x point of the ellipse")
y_middle: float = Field(description="The central y point of the ellipse")
x_radius: float = Field(
description="The radius along the x axis of the ellipse", exc_min=0
description="The radius along the x axis of the ellipse", gt=0
)
y_radius: float = Field(
description="The radius along the y axis of the ellipse", exc_min=0
description="The radius along the y axis of the ellipse", gt=0
)
angle: float = Field(description="The angle of the ellipse (degrees)", default=0.0)

Expand All @@ -354,8 +365,10 @@ def mask(self, points: AxesPoints[Axis]) -> np.ndarray:

def find_regions(obj) -> Iterator[Region[Axis]]:
"""Recursively yield Regions from obj and its children."""
if hasattr(obj, "__pydantic_model__") and issubclass(
obj.__pydantic_model__, BaseModel
if (
hasattr(obj, "__pydantic_model__")
and issubclass(obj.__pydantic_model__, BaseModel)
or is_dataclass(obj)
):
if isinstance(obj, Region):
yield obj
Expand Down
Loading

0 comments on commit f49320c

Please sign in to comment.