Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add schema support for asdf standard tags #32

Merged
merged 7 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion asdf_pydantic/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,36 @@
from typing import Optional
"""

## Adding existing ASDF tags as a field
Type annotation must be added to the field to specify the ASDF tag to use in the
ASDF schema. There are a few options to do this:

- Use `AsdfTag` to specify the tag URI.
- Use `WithAsdfSchema` and pass in a dictionary to extend the schema with
additional properties. The key `"$ref"` can be used to specify the tag URI.

from asdf_pydantic import AsdfPydanticModel
from asdf_pydantic.schema import AsdfTag
from astropy.table import Table

class MyModel(AsdfPydanticModel):
table: Annotated[Table, AsdfTag("http://stsci.edu/schemas/asdf.org/table/table-1.1.0")]

For more customization of the ASDF schema output, you can use `WithAsdfSchema` to
extend the schema with additional properties.

# Changing the title of the field
table: Annotated[
Table,
WithAsdfSchema({
"title": "TABLE",
"$ref": "http://stsci.edu/schemas/asdf.org/table/table-1.1.0"
}),
]
"""

from typing import Literal, Optional

from pydantic import WithJsonSchema
from pydantic.json_schema import GenerateJsonSchema

DEFAULT_ASDF_SCHEMA_REF_TEMPLATE = "#/definitions/{model}"
Expand Down Expand Up @@ -60,3 +91,20 @@ def generate(self, schema, mode="validation"):
}

return json_schema


class WithAsdfSchema(WithJsonSchema):
def __init__(self, asdf_schema: dict, **kwargs):
super().__init__(asdf_schema, **kwargs)


def AsdfTag(tag: str, mode: Literal["auto", "ref", "tag"] = "auto") -> WithAsdfSchema:
if mode == "auto":
parsed_mode = "tag" if tag.startswith("tag") else "ref"
else:
parsed_mode = mode

if parsed_mode == "tag":
return WithAsdfSchema({"tag": tag})
else:
return WithAsdfSchema({"$ref": tag})
71 changes: 71 additions & 0 deletions tests/examples/test_astropy_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import Annotated

import asdf
import astropy.units as u
import pytest
import yaml
from asdf.extension import Extension
from astropy.table import Table
from astropy.units import Quantity

from asdf_pydantic import AsdfPydanticConverter, AsdfPydanticModel
from asdf_pydantic.schema import AsdfTag


class Database(AsdfPydanticModel):
_tag = "asdf://asdf-pydantic/examples/tags/database-1.0.0"
positions: Annotated[
Table, AsdfTag("http://stsci.edu/schemas/asdf.org/table/table-1.1.0")
]


@pytest.fixture()
def asdf_extension():
"""Registers an ASDF extension containing models for this test."""
AsdfPydanticConverter.add_models(Database)

class TestExtension(Extension):
extension_uri = "asdf://asdf-pydantic/examples/extensions/test-1.0.0"

converters = [AsdfPydanticConverter()] # type: ignore
tags = [*AsdfPydanticConverter().tags] # type: ignore

asdf.get_config().add_extension(TestExtension())

with asdf.config_context() as asdf_config:
asdf_config.add_resource_mapping(
{
yaml.safe_load(Database.model_asdf_schema())[
"id"
]: Database.model_asdf_schema()
}
)
print(Database.model_asdf_schema())
asdf_config.add_extension(TestExtension())
yield asdf_config


@pytest.mark.usefixtures("asdf_extension")
def test_convert_to_asdf(tmp_path):
database = Database(
positions=Table(
{
"x": Quantity([1, 2, 3], u.m),
"y": Quantity([4, 5, 6], u.m),
}
)
)
asdf.AsdfFile({"data": database}).write_to(tmp_path / "test.asdf")

with asdf.open(tmp_path / "test.asdf") as af:
assert isinstance(af.tree["data"], Database)
assert isinstance(af.tree["data"].positions, Table)


@pytest.mark.usefixtures("asdf_extension")
def test_check_schema():
"""Tests the model schema is correct."""
schema = yaml.safe_load(Database.model_asdf_schema())
asdf.schema.check_schema(schema)
31 changes: 31 additions & 0 deletions tests/schema_validation_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tempfile import NamedTemporaryFile
from typing import Annotated

import asdf
import pydantic
Expand All @@ -7,6 +8,7 @@
from asdf.extension import Extension

from asdf_pydantic import AsdfPydanticConverter
from asdf_pydantic.model import AsdfPydanticModel
from tests.examples.shapes import AsdfRectangle
from tests.examples.tree import AsdfTreeNode

Expand Down Expand Up @@ -136,3 +138,32 @@ def test_given_child_field_contains_asdf_object_then_schema_has_child_tag():
child_schema = schema["definitions"]["AsdfNode"]["properties"]["child"]

assert {"tag": AsdfTreeNode._tag} in child_schema["anyOf"]


########################################################################################
# AsdfTag
########################################################################################
from asdf_pydantic.schema import AsdfTag # noqa: E402


@pytest.mark.parametrize(
"asdf_tag_str, mode, expected_ref_key",
[
("http://stsci.edu/schemas/asdf/unit/quantity-1.2.0", "auto", "$ref"),
("http://stsci.edu/schemas/asdf/unit/quantity-1.2.0", "ref", "$ref"),
("tag:stsci.edu:asdf/table/table-1.1.0", "auto", "tag"),
("tag:stsci.edu:asdf/table/table-1.1.0", "tag", "tag"),
],
)
def test_tag_mode(asdf_tag_str: str, mode, expected_ref_key):
"""Test that schema correctly has ``$ref:`` or ``tag:`` depending on the
selected mode.
"""
from astropy.table import Table

class TestModel(AsdfPydanticModel):
_tag = "asdf://asdf-pydantic/examples/tags/test-model-1.0.0"
table: Annotated[Table, AsdfTag(asdf_tag_str, mode=mode)]

schema = yaml.safe_load(TestModel.model_asdf_schema())
assert expected_ref_key in schema["properties"]["table"]