Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 优化依赖注入在 pydantic v2 下的性能 #2870

Merged
merged 2 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 67 additions & 31 deletions nonebot/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
"""

from collections.abc import Generator
from functools import cached_property
from dataclasses import dataclass, is_dataclass
from typing_extensions import Self, get_args, get_origin, is_typeddict
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Optional,
Protocol,
Annotated,
overload,
)

from pydantic import VERSION, BaseModel
Expand Down Expand Up @@ -46,8 +49,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...
"DEFAULT_CONFIG",
"FieldInfo",
"ModelField",
"TypeAdapter",
"extract_field_info",
"model_field_validate",
"model_fields",
"model_config",
"model_dump",
Expand All @@ -63,9 +66,10 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...


if PYDANTIC_V2: # pragma: pydantic-v2
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
from pydantic_core import CoreSchema, core_schema
from pydantic._internal._repr import display_as_type
from pydantic import TypeAdapter, GetCoreSchemaHandler
from pydantic.fields import FieldInfo as BaseFieldInfo

Required = Ellipsis
Expand Down Expand Up @@ -125,6 +129,25 @@ def construct(
"""Construct a ModelField from given infos."""
return cls._construct(name, annotation, field_info or FieldInfo())

def __hash__(self) -> int:
# Each ModelField is unique for our purposes,
# to allow store them in a set.
return id(self)

@cached_property
def type_adapter(self) -> TypeAdapter:
"""TypeAdapter of the field.

Cache the TypeAdapter to avoid creating it multiple times.
Pydantic v2 uses too much cpu time to create TypeAdapter.

See: https://github.com/pydantic/pydantic/issues/9834
"""
return TypeAdapter(
Annotated[self.annotation, self.field_info],
config=None if self._annotation_has_config() else DEFAULT_CONFIG,
)

def _annotation_has_config(self) -> bool:
"""Check if the annotation has config.

Expand Down Expand Up @@ -152,10 +175,9 @@ def _type_display(self):
"""Get the display of the type of the field."""
return display_as_type(self.annotation)

def __hash__(self) -> int:
# Each ModelField is unique for our purposes,
# to allow store them in a set.
return id(self)
def validate_value(self, value: Any) -> Any:
"""Validate the value pass to the field."""
return self.type_adapter.validate_python(value)

def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
"""Get FieldInfo init kwargs from a FieldInfo instance."""
Expand All @@ -164,15 +186,6 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
kwargs["annotation"] = field_info.rebuild_annotation()
return kwargs

def model_field_validate(
model_field: ModelField, value: Any, config: Optional[ConfigDict] = None
) -> Any:
"""Validate the value pass to the field."""
type: Any = Annotated[model_field.annotation, model_field.field_info]
return TypeAdapter(
type, config=None if model_field._annotation_has_config() else config
).validate_python(value)

def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""

Expand Down Expand Up @@ -305,6 +318,45 @@ def construct(
)
return cls._construct(name, annotation, field_info or FieldInfo())

def validate_value(self, value: Any) -> Any:
"""Validate the value pass to the field."""
v, errs_ = self.validate(value, {}, loc=())
if errs_:
raise ValueError(value, self)
return v

class TypeAdapter(Generic[T]):
@overload
def __init__(
self,
type: type[T],
*,
config: Optional[ConfigDict] = ...,
) -> None: ...

@overload
def __init__(
self,
type: Any,
*,
config: Optional[ConfigDict] = ...,
) -> None: ...

def __init__(
self,
type: Any,
*,
config: Optional[ConfigDict] = None,
) -> None:
self.type = type
self.config = config

def validate_python(self, value: Any) -> T:
return type_validate_python(self.type, value)

def validate_json(self, value: Union[str, bytes]) -> T:
return type_validate_json(self.type, value)

def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
"""Get FieldInfo init kwargs from a FieldInfo instance."""

Expand All @@ -314,22 +366,6 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
kwargs.update(field_info.extra)
return kwargs

def model_field_validate(
model_field: ModelField, value: Any, config: Optional[type[ConfigDict]] = None
) -> Any:
"""Validate the value pass to the field.

Set config before validate to ensure validate correctly.
"""

if model_field.model_config is not config:
model_field.set_config(config or ConfigDict)

v, errs_ = model_field.validate(value, {}, loc=())
if errs_:
raise ValueError(value, model_field)
return v

def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""

Expand Down
4 changes: 2 additions & 2 deletions nonebot/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from loguru import logger

from nonebot.compat import ModelField
from nonebot.exception import TypeMisMatch
from nonebot.typing import evaluate_forwardref
from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
Expand Down Expand Up @@ -51,6 +51,6 @@ def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配"""

try:
return model_field_validate(field, value, DEFAULT_CONFIG)
return field.validate_value(value)
except ValueError:
raise TypeMisMatch(field, value)
20 changes: 18 additions & 2 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any, Optional
from dataclasses import dataclass
from typing import Any, Optional, Annotated

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from nonebot.compat import (
DEFAULT_CONFIG,
Required,
FieldInfo,
TypeAdapter,
PydanticUndefined,
model_dump,
custom_validation,
Expand All @@ -31,6 +32,21 @@ async def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test"


@pytest.mark.asyncio
async def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])

assert t.validate_python(2) == 2

with pytest.raises(ValidationError):
t.validate_python(0)

assert t.validate_json("2") == 2

with pytest.raises(ValidationError):
t.validate_json("0")


@pytest.mark.asyncio
async def test_model_dump():
class TestModel(BaseModel):
Expand Down