Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the need for uses_tagged_union deco #140

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 3 additions & 35 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
get_origin,
get_type_hints,
)
from typing import Any, Generic, Literal, TypeVar

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
Expand Down Expand Up @@ -140,24 +133,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
return super_cls


def uses_tagged_union(cls_or_func: T) -> T:
"""
T = TypeVar("T", type, Callable)
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for v in get_type_hints(cls_or_func).values():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_reference(cls_or_func)
return cls_or_func


_tagged_unions: dict[type, _TaggedUnion] = {}


Expand All @@ -168,7 +143,6 @@ def __init__(self, base_class: type, discriminator: str):
self._discriminator = discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._subclasses: list[type] = []
self._references: set[type | Callable] = set()

def add_member(self, cls: type):
if cls in self._subclasses:
Expand All @@ -177,14 +151,8 @@ def add_member(self, cls: type):
for member in self._subclasses:
if member is not cls:
_TaggedUnion._rebuild(member)
for ref in self._references:
_TaggedUnion._rebuild(ref)

def add_reference(self, cls_or_func: type | Callable):
self._references.add(cls_or_func)

@staticmethod
# https://github.com/bluesky/scanspec/issues/133
def _rebuild(cls_or_func: type | Callable):
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
Expand All @@ -194,14 +162,14 @@ def _rebuild(cls_or_func: type | Callable):

def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:
return tagged_union_schema(
make_schema(tuple(self._subclasses), handler),
_make_schema(tuple(self._subclasses), handler),
discriminator=self._discriminator,
ref=self._base_class.__name__,
)


@lru_cache(1)
def make_schema(members: tuple[type, ...], handler):
def _make_schema(members: tuple[type, ...], handler):
return {member.__name__: handler(member) for member in members}


Expand Down
4 changes: 1 addition & 3 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import Field
from pydantic.dataclasses import dataclass

from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union
from scanspec.core import AxesPoints, Frames, Path

from .specs import Line, Spec

Expand All @@ -25,7 +25,6 @@
Points = str | list[float]


@uses_tagged_union
@dataclass
class ValidResponse:
"""Response model for spec validation."""
Expand All @@ -42,7 +41,6 @@ class PointsFormat(str, Enum):
BASE64_ENCODED = "BASE64_ENCODED"


@uses_tagged_union
@dataclass
class PointsRequest:
"""A request for generated scan points."""
Expand Down
16 changes: 0 additions & 16 deletions tests/test_basemodel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import pytest
from pydantic import BaseModel, TypeAdapter
from pydantic.dataclasses import dataclass

from scanspec.core import StrictConfig, uses_tagged_union
from scanspec.specs import Line, Spec


@uses_tagged_union
class Foo(BaseModel):
spec: Spec

Expand Down Expand Up @@ -41,16 +38,3 @@ def test_type_adapter(model: Foo):
as_json = model.model_dump_json()
deserialized = type_adapter.validate_json(as_json)
assert deserialized == model


def test_schema_updates_with_new_values():
old_schema = TypeAdapter(Foo).json_schema()

@dataclass(config=StrictConfig)
class Splat(Spec[str]): # NOSONAR
def axes(self) -> list[str]:
return ["*"]

new_schema = TypeAdapter(Foo).json_schema()

assert new_schema != old_schema