From 9e6e43657e6f40154f214fc5c0a8ac4e092e5fb3 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Tue, 27 Aug 2024 15:49:30 +0000 Subject: [PATCH] Remove the need for uses_tagged_union deco At present we support people adding to the list of Specs after importing the service. This requires a uses_tagged_union decorator on any BaseModel that references a Spec. We cannot currently think of a use case for Specs being implemented outside of scanspec.specs, so removing this requirement means we can ditch the decorator. It also makes the docs (next PR) easier to write: - Put your specs in src/scanspec/specs.py - Release a new version - Update both blueapi and the scanspec service to use that new version --- src/scanspec/core.py | 38 +++----------------------------------- src/scanspec/service.py | 4 +--- tests/test_basemodel.py | 16 ---------------- 3 files changed, 4 insertions(+), 54 deletions(-) diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 6cb62d47..a0ea6169 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -3,14 +3,7 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from functools import lru_cache from inspect import isclass -from typing import ( - Any, - Generic, - Literal, - TypeVar, - get_origin, - get_type_hints, -) +from typing import Any, Generic, Literal, TypeVar import numpy as np from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler @@ -140,24 +133,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler): return super_cls -def uses_tagged_union(cls_or_func: T) -> T: - """ - T = TypeVar("T", type, Callable) - Decorator that processes the type hints of a class or function to detect and - register any tagged unions. If a tagged union is detected in the type hints, - it registers the class or function as a referrer to that tagged union. - Args: - cls_or_func (T): The class or function to be processed for tagged unions. - Returns: - T: The original class or function, unmodified. - """ - for v in get_type_hints(cls_or_func).values(): - tagged_union = _tagged_unions.get(get_origin(v) or v, None) - if tagged_union: - tagged_union.add_reference(cls_or_func) - return cls_or_func - - _tagged_unions: dict[type, _TaggedUnion] = {} @@ -168,7 +143,6 @@ def __init__(self, base_class: type, discriminator: str): self._discriminator = discriminator # The members of the tagged union, i.e. subclasses of the baseclass self._subclasses: list[type] = [] - self._references: set[type | Callable] = set() def add_member(self, cls: type): if cls in self._subclasses: @@ -177,14 +151,8 @@ def add_member(self, cls: type): for member in self._subclasses: if member is not cls: _TaggedUnion._rebuild(member) - for ref in self._references: - _TaggedUnion._rebuild(ref) - - def add_reference(self, cls_or_func: type | Callable): - self._references.add(cls_or_func) @staticmethod - # https://github.com/bluesky/scanspec/issues/133 def _rebuild(cls_or_func: type | Callable): if isclass(cls_or_func): if is_pydantic_dataclass(cls_or_func): @@ -194,14 +162,14 @@ def _rebuild(cls_or_func: type | Callable): def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema: return tagged_union_schema( - make_schema(tuple(self._subclasses), handler), + _make_schema(tuple(self._subclasses), handler), discriminator=self._discriminator, ref=self._base_class.__name__, ) @lru_cache(1) -def make_schema(members: tuple[type, ...], handler): +def _make_schema(members: tuple[type, ...], handler): return {member.__name__: handler(member) for member in members} diff --git a/src/scanspec/service.py b/src/scanspec/service.py index b05e4113..52121833 100644 --- a/src/scanspec/service.py +++ b/src/scanspec/service.py @@ -10,7 +10,7 @@ from pydantic import Field from pydantic.dataclasses import dataclass -from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union +from scanspec.core import AxesPoints, Frames, Path from .specs import Line, Spec @@ -25,7 +25,6 @@ Points = str | list[float] -@uses_tagged_union @dataclass class ValidResponse: """Response model for spec validation.""" @@ -42,7 +41,6 @@ class PointsFormat(str, Enum): BASE64_ENCODED = "BASE64_ENCODED" -@uses_tagged_union @dataclass class PointsRequest: """A request for generated scan points.""" diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py index 85d83741..cb809747 100644 --- a/tests/test_basemodel.py +++ b/tests/test_basemodel.py @@ -1,12 +1,9 @@ import pytest from pydantic import BaseModel, TypeAdapter -from pydantic.dataclasses import dataclass -from scanspec.core import StrictConfig, uses_tagged_union from scanspec.specs import Line, Spec -@uses_tagged_union class Foo(BaseModel): spec: Spec @@ -41,16 +38,3 @@ def test_type_adapter(model: Foo): as_json = model.model_dump_json() deserialized = type_adapter.validate_json(as_json) assert deserialized == model - - -def test_schema_updates_with_new_values(): - old_schema = TypeAdapter(Foo).json_schema() - - @dataclass(config=StrictConfig) - class Splat(Spec[str]): # NOSONAR - def axes(self) -> list[str]: - return ["*"] - - new_schema = TypeAdapter(Foo).json_schema() - - assert new_schema != old_schema