Skip to content

Commit

Permalink
Merge pull request #145 from ImogenBits/ref_annotations
Browse files Browse the repository at this point in the history
move ref annotation validation into pydantic schema
  • Loading branch information
Benezivas authored Nov 12, 2023
2 parents 8177971 + 507210b commit 8736216
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions algobattle/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@
GetCoreSchemaHandler,
ValidationInfo,
)
from pydantic.main import BaseModel
from pydantic_core import CoreSchema
from pydantic_core.core_schema import with_info_after_validator_function
from pydantic_core.core_schema import (
with_info_after_validator_function,
with_info_wrap_validator_function,
ValidatorFunctionWrapHandler,
)

from algobattle.util import (
EncodableModel,
Expand Down Expand Up @@ -503,20 +508,22 @@ class InstanceSolutionModel(EncodableModel):
"""Base class for Instance and solution models."""

@classmethod
def model_validate( # noqa: D102
cls,
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> Self:
model = super().model_validate(obj, strict=strict, from_attributes=from_attributes, context=context)
model_type = "instance" if issubclass(cls, InstanceModel) else "solution"
def __get_pydantic_core_schema__(cls, source: type[BaseModel], handler: GetCoreSchemaHandler) -> CoreSchema:
schema = handler(cls)
try:
model_type = "instance" if issubclass(cls, InstanceModel) else "solution"
except NameError:
return schema
if cls._validate_with_self(model_type):
context = (context or {}) | {"self": model, model_type: model}
model = super().model_validate(obj, context=context)
return model

def validate_with_self(input: object, validate: ValidatorFunctionWrapHandler, info: ValidationInfo) -> Self:
self = validate(input)
if info.context is None or "self" not in info.context:
self = cls.model_validate(input, context=(info.context or {}) | {"self": self, model_type: self})
return self

schema = with_info_wrap_validator_function(validate_with_self, schema)
return schema

@classmethod
def _annotation_needs_self(cls, annotation: object, model_type: ModelType) -> bool:
Expand Down

0 comments on commit 8736216

Please sign in to comment.