From 92dff902b1bdfcdc89a21751dfa592acb217c84b Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 11 Aug 2024 06:10:18 +0000 Subject: [PATCH 1/2] :zap: improve pydantic v2 performance --- nonebot/compat.py | 98 ++++++++++++++++++++++++----------- nonebot/dependencies/utils.py | 4 +- tests/test_compat.py | 20 ++++++- 3 files changed, 87 insertions(+), 35 deletions(-) diff --git a/nonebot/compat.py b/nonebot/compat.py index b28eedb58b41..0f90e5c2207d 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -8,17 +8,20 @@ """ from collections.abc import Generator +from functools import cached_property from dataclasses import dataclass, is_dataclass from typing_extensions import Self, get_args, get_origin, is_typeddict from typing import ( TYPE_CHECKING, Any, Union, + Generic, TypeVar, Callable, Optional, Protocol, Annotated, + overload, ) from pydantic import VERSION, BaseModel @@ -46,8 +49,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... "DEFAULT_CONFIG", "FieldInfo", "ModelField", + "TypeAdapter", "extract_field_info", - "model_field_validate", "model_fields", "model_config", "model_dump", @@ -63,9 +66,10 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... if PYDANTIC_V2: # pragma: pydantic-v2 + from pydantic import GetCoreSchemaHandler + from pydantic import TypeAdapter as TypeAdapter from pydantic_core import CoreSchema, core_schema from pydantic._internal._repr import display_as_type - from pydantic import TypeAdapter, GetCoreSchemaHandler from pydantic.fields import FieldInfo as BaseFieldInfo Required = Ellipsis @@ -125,6 +129,25 @@ def construct( """Construct a ModelField from given infos.""" return cls._construct(name, annotation, field_info or FieldInfo()) + def __hash__(self) -> int: + # Each ModelField is unique for our purposes, + # to allow store them in a set. + return id(self) + + @cached_property + def type_adapter(self) -> TypeAdapter: + """TypeAdapter of the field. + + Cache the TypeAdapter to avoid creating it multiple times. + Pydantic v2 uses too much cpu time to create TypeAdapter. + + See: https://github.com/pydantic/pydantic/issues/9834 + """ + return TypeAdapter( + Annotated[self.annotation, self.field_info], + config=None if self._annotation_has_config() else DEFAULT_CONFIG, + ) + def _annotation_has_config(self) -> bool: """Check if the annotation has config. @@ -152,10 +175,9 @@ def _type_display(self): """Get the display of the type of the field.""" return display_as_type(self.annotation) - def __hash__(self) -> int: - # Each ModelField is unique for our purposes, - # to allow store them in a set. - return id(self) + def validate(self, value: Any) -> Any: + """Validate the value pass to the field.""" + return self.type_adapter.validate_python(value) def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: """Get FieldInfo init kwargs from a FieldInfo instance.""" @@ -164,15 +186,6 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: kwargs["annotation"] = field_info.rebuild_annotation() return kwargs - def model_field_validate( - model_field: ModelField, value: Any, config: Optional[ConfigDict] = None - ) -> Any: - """Validate the value pass to the field.""" - type: Any = Annotated[model_field.annotation, model_field.field_info] - return TypeAdapter( - type, config=None if model_field._annotation_has_config() else config - ).validate_python(value) - def model_fields(model: type[BaseModel]) -> list[ModelField]: """Get field list of a model.""" @@ -305,6 +318,45 @@ def construct( ) return cls._construct(name, annotation, field_info or FieldInfo()) + def validate(self, value: Any) -> Any: + """Validate the value pass to the field.""" + v, errs_ = super().validate(value, {}, loc=()) + if errs_: + raise ValueError(value, self) + return v + + class TypeAdapter(Generic[T]): + @overload + def __init__( + self, + type: type[T], + *, + config: Optional[ConfigDict] = ..., + ) -> None: ... + + @overload + def __init__( + self, + type: Any, + *, + config: Optional[ConfigDict] = ..., + ) -> None: ... + + def __init__( + self, + type: Any, + *, + config: Optional[ConfigDict] = None, + ) -> None: + self.type = type + self.config = config + + def validate_python(self, value: Any) -> T: + return type_validate_python(self.type, value) + + def validate_json(self, value: Union[str, bytes]) -> T: + return type_validate_json(self.type, value) + def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: """Get FieldInfo init kwargs from a FieldInfo instance.""" @@ -314,22 +366,6 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: kwargs.update(field_info.extra) return kwargs - def model_field_validate( - model_field: ModelField, value: Any, config: Optional[type[ConfigDict]] = None - ) -> Any: - """Validate the value pass to the field. - - Set config before validate to ensure validate correctly. - """ - - if model_field.model_config is not config: - model_field.set_config(config or ConfigDict) - - v, errs_ = model_field.validate(value, {}, loc=()) - if errs_: - raise ValueError(value, model_field) - return v - def model_fields(model: type[BaseModel]) -> list[ModelField]: """Get field list of a model.""" diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index 0852683f1c8f..dfa6e361e097 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -9,9 +9,9 @@ from loguru import logger +from nonebot.compat import ModelField from nonebot.exception import TypeMisMatch from nonebot.typing import evaluate_forwardref -from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: @@ -51,6 +51,6 @@ def check_field_type(field: ModelField, value: Any) -> Any: """检查字段类型是否匹配""" try: - return model_field_validate(field, value, DEFAULT_CONFIG) + return field.validate(value) except ValueError: raise TypeMisMatch(field, value) diff --git a/tests/test_compat.py b/tests/test_compat.py index a50da686890b..151f250e165e 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,13 +1,14 @@ -from typing import Any, Optional from dataclasses import dataclass +from typing import Any, Optional, Annotated import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from nonebot.compat import ( DEFAULT_CONFIG, Required, FieldInfo, + TypeAdapter, PydanticUndefined, model_dump, custom_validation, @@ -31,6 +32,21 @@ async def test_field_info(): assert FieldInfo(test="test").extra["test"] == "test" +@pytest.mark.asyncio +async def test_type_adapter(): + t = TypeAdapter(Annotated[int, FieldInfo(ge=1)]) + + assert t.validate_python(2) == 2 + + with pytest.raises(ValidationError): + t.validate_python(0) + + assert t.validate_json("2") == 2 + + with pytest.raises(ValidationError): + t.validate_json("0") + + @pytest.mark.asyncio async def test_model_dump(): class TestModel(BaseModel): From da6526864cb9fa0427ddc22d6e241a2d81af2257 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 11 Aug 2024 06:42:08 +0000 Subject: [PATCH 2/2] :bug: remove method override --- nonebot/compat.py | 6 +++--- nonebot/dependencies/utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nonebot/compat.py b/nonebot/compat.py index 0f90e5c2207d..8b73aef9b440 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -175,7 +175,7 @@ def _type_display(self): """Get the display of the type of the field.""" return display_as_type(self.annotation) - def validate(self, value: Any) -> Any: + def validate_value(self, value: Any) -> Any: """Validate the value pass to the field.""" return self.type_adapter.validate_python(value) @@ -318,9 +318,9 @@ def construct( ) return cls._construct(name, annotation, field_info or FieldInfo()) - def validate(self, value: Any) -> Any: + def validate_value(self, value: Any) -> Any: """Validate the value pass to the field.""" - v, errs_ = super().validate(value, {}, loc=()) + v, errs_ = self.validate(value, {}, loc=()) if errs_: raise ValueError(value, self) return v diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index dfa6e361e097..55f490ccd7bd 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -51,6 +51,6 @@ def check_field_type(field: ModelField, value: Any) -> Any: """检查字段类型是否匹配""" try: - return field.validate(value) + return field.validate_value(value) except ValueError: raise TypeMisMatch(field, value)