diff --git a/algobattle/problem.py b/algobattle/problem.py index 0fd3bd9b..65cabace 100644 --- a/algobattle/problem.py +++ b/algobattle/problem.py @@ -28,8 +28,13 @@ GetCoreSchemaHandler, ValidationInfo, ) +from pydantic.main import BaseModel from pydantic_core import CoreSchema -from pydantic_core.core_schema import with_info_after_validator_function +from pydantic_core.core_schema import ( + with_info_after_validator_function, + with_info_wrap_validator_function, + ValidatorFunctionWrapHandler, +) from algobattle.util import ( EncodableModel, @@ -503,20 +508,22 @@ class InstanceSolutionModel(EncodableModel): """Base class for Instance and solution models.""" @classmethod - def model_validate( # noqa: D102 - cls, - obj: Any, - *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: dict[str, Any] | None = None, - ) -> Self: - model = super().model_validate(obj, strict=strict, from_attributes=from_attributes, context=context) - model_type = "instance" if issubclass(cls, InstanceModel) else "solution" + def __get_pydantic_core_schema__(cls, source: type[BaseModel], handler: GetCoreSchemaHandler) -> CoreSchema: + schema = handler(cls) + try: + model_type = "instance" if issubclass(cls, InstanceModel) else "solution" + except NameError: + return schema if cls._validate_with_self(model_type): - context = (context or {}) | {"self": model, model_type: model} - model = super().model_validate(obj, context=context) - return model + + def validate_with_self(input: object, validate: ValidatorFunctionWrapHandler, info: ValidationInfo) -> Self: + self = validate(input) + if info.context is None or "self" not in info.context: + self = cls.model_validate(input, context=(info.context or {}) | {"self": self, model_type: self}) + return self + + schema = with_info_wrap_validator_function(validate_with_self, schema) + return schema @classmethod def _annotation_needs_self(cls, annotation: object, model_type: ModelType) -> bool: